diff --git a/agentfile/__init__.py b/agentfile/__init__.py index 36148efd..b63f1154 100644 --- a/agentfile/__init__.py +++ b/agentfile/__init__.py @@ -1,5 +1,5 @@ -from agentfile.control_plane import FastAPIControlPlane -from agentfile.launchers import LocalLauncher +from agentfile.control_plane import ControlPlaneServer +from agentfile.launchers import LocalLauncher, ServerLauncher from agentfile.message_queues import SimpleMessageQueue from agentfile.orchestrators import ( AgentOrchestrator, @@ -19,8 +19,9 @@ "SimpleMessageQueue", # launchers "LocalLauncher", + "ServerLauncher", # control planes - "FastAPIControlPlane", + "ControlPlaneServer", # orchestrators "AgentOrchestrator", "PipelineOrchestrator", diff --git a/agentfile/app/__init__.py b/agentfile/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agentfile/app/app.py b/agentfile/app/app.py new file mode 100644 index 00000000..7832151c --- /dev/null +++ b/agentfile/app/app.py @@ -0,0 +1,207 @@ +import httpx +import logging +import pprint +from typing import Any, Optional + +from textual.app import App, ComposeResult +from textual.containers import VerticalScroll +from textual.reactive import reactive +from textual.widgets import Button, Header, Footer, Static, Input + +from agentfile.app.components.human_list import HumanTaskList +from agentfile.app.components.service_list import ServicesList +from agentfile.app.components.task_list import TasksList +from agentfile.app.components.types import ButtonType +from agentfile.types import TaskDefinition + + +class LlamaAgentsMonitor(App): + CSS = """ + Screen { + layout: grid; + grid-size: 2; + grid-columns: 1fr 2fr; + padding: 0; + } + + #left-panel { + width: 100%; + height: 100%; + } + + #right-panel { + width: 100%; + height: 100%; + } + + .section { + background: $panel; + padding: 1; + margin-bottom: 0; + } + + #tasks { + height: auto; + max-height: 50%; + } + + #services { + height: auto; + max-height: 50%; + } + + VerticalScroll { + height: auto; + max-height: 100%; + border: solid $primary; + margin-bottom: 1; + } + + #right-panel VerticalScroll { + max-height: 100%; + } + + Button { + width: 100%; + margin-bottom: 1; + } + + #details { + background: $boost; + padding: 1; + text-align: left; + } + + #new-task { + dock: bottom; + margin-bottom: 1; + width: 100%; + } + """ + + details = reactive("") + selected_service_type = reactive("") + selected_service_url = reactive("") + + def __init__(self, control_plane_url: str, **kwargs: Any): + self.control_plane_url = control_plane_url + super().__init__(**kwargs) + + def compose(self) -> ComposeResult: + yield Header() + with Static(id="left-panel"): + yield ServicesList(id="services", control_plane_url=self.control_plane_url) + yield TasksList(id="tasks", control_plane_url=self.control_plane_url) + with VerticalScroll(id="right-panel"): + yield Static("Task or service details", id="details") + yield Input(placeholder="Enter: New task", id="new-task") + yield Footer() + + async def on_mount(self) -> None: + self.set_interval(5, self.refresh_details) + + async def watch_details(self, new_details: str) -> None: + if not new_details: + return + + selected_type = ButtonType(new_details.split(":")[0].strip()) + + if selected_type == ButtonType.SERVICE: + self.query_one("#details").update(new_details) + elif selected_type == ButtonType.TASK: + self.query_one("#details").update(new_details) + + async def watch_selected_service_type(self, new_service_type: str) -> None: + if not new_service_type: + return + + if 
new_service_type == "human_service": + await self.query_one("#right-panel").mount( + HumanTaskList(self.selected_service_url), after=0 + ) + else: + try: + await self.query_one(HumanTaskList).remove() + except Exception: + # not mounted yet + pass + + async def refresh_details( + self, + button_type: Optional[ButtonType] = None, + selected_label: Optional[str] = None, + ) -> None: + if not self.details and button_type is None and selected_label is None: + return + + selected_type = button_type or ButtonType(self.details.split(":")[0].strip()) + selected_label = ( + selected_label or self.details.split(":")[1].split("\n")[0].strip() + ) + + if selected_type == ButtonType.SERVICE: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.control_plane_url}/services/{selected_label}" + ) + service_def = response.json() + + service_dict = service_def + service_url = "" + if service_def.get("host") and service_def.get("port"): + service_url = f"http://{service_def['host']}:{service_def['port']}" + response = await client.get(f"{service_url}/") + service_dict = response.json() + + # format the service details nicely + service_string = pprint.pformat(service_dict) + + self.details = ( + f"{selected_type.value}: {selected_label}\n\n{service_string}" + ) + + self.selected_service_url = service_url + self.selected_service_type = service_dict.get("type") + elif selected_type == ButtonType.TASK: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.control_plane_url}/tasks/{selected_label}" + ) + task_dict = response.json() + + # flatten the TaskResult object + if task_dict["state"].get("result"): + task_dict["state"]["result"] = task_dict["state"]["result"]["result"] + + # format the task details nicely + task_string = pprint.pformat(task_dict) + + self.details = f"{selected_type.value}: {selected_label}\n\n{task_string}" + self.selected_service_type = "" + self.selected_service_url = "" + + async def on_button_pressed(self, event: Button.Pressed) -> None: + # Update the details panel with the selected item + await self.refresh_details( + button_type=event.button.type, selected_label=event.button.label + ) + + async def on_input_submitted(self, event: Input.Submitted) -> None: + new_task = TaskDefinition(input=event.value).model_dump() + async with httpx.AsyncClient() as client: + await client.post(f"{self.control_plane_url}/tasks", json=new_task) + + # clear the input + self.query_one("#new-task").value = "" + + +def run(control_plane_url: str = "http://127.0.0.1:8000") -> None: + # remove info logging for httpx + logging.getLogger("httpx").setLevel(logging.WARNING) + + app = LlamaAgentsMonitor(control_plane_url=control_plane_url) + app.run() + + +if __name__ == "__main__": + run() diff --git a/agentfile/app/components/__init__.py b/agentfile/app/components/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agentfile/app/components/human_list.py b/agentfile/app/components/human_list.py new file mode 100644 index 00000000..938a3e1f --- /dev/null +++ b/agentfile/app/components/human_list.py @@ -0,0 +1,96 @@ +import httpx +from typing import Any, List + +from textual.app import ComposeResult +from textual.containers import VerticalScroll, Container +from textual.reactive import reactive +from textual.widgets import Button, Static, Input + +from agentfile.app.components.types import ButtonType +from agentfile.types import HumanResponse, TaskDefinition + + +class HumanTaskButton(Button): + type: ButtonType = ButtonType.HUMAN + 
task_id: str = "" + + +class HumanTaskList(Static): + tasks: List[TaskDefinition] = reactive([]) + selected_task: str = reactive("") + + def __init__(self, human_service_url: str, **kwargs: Any): + self.human_service_url = human_service_url + super().__init__(**kwargs) + + def compose(self) -> ComposeResult: + with VerticalScroll(id="human-tasks-scroll"): + for task in self.tasks: + button = HumanTaskButton(task.input) + button.task_id = task.task_id + yield button + + async def on_mount(self) -> None: + self.set_interval(2, self.refresh_tasks) + + async def refresh_tasks(self) -> None: + async with httpx.AsyncClient() as client: + response = await client.get(f"{self.human_service_url}/tasks") + tasks = response.json() + + new_tasks = [] + for task in tasks: + new_tasks.append(TaskDefinition(**task)) + + self.tasks = [*new_tasks] + + async def watch_tasks(self, new_tasks: List[TaskDefinition]) -> None: + try: + tasks_scroll = self.query_one("#human-tasks-scroll") + await tasks_scroll.remove_children() + for task in new_tasks: + button = HumanTaskButton(task.input) + button.task_id = task.task_id + await tasks_scroll.mount(button) + except Exception: + pass + + async def watch_selected_task(self, new_task: str) -> None: + if not new_task: + return + + try: + await self.query_one("#respond").remove() + except Exception: + # not mounted yet + pass + + container = Container( + Static(f"Task: {new_task}"), + Input( + placeholder="Type your response here", + ), + id="respond", + ) + + # mount the container + await self.mount(container) + + def on_button_pressed(self, event: Button.Pressed) -> None: + # Update the details panel with the selected item + self.selected_task = event.button.label + + async def on_input_submitted(self, event: Input.Submitted) -> None: + response = HumanResponse(result=event.value).model_dump() + async with httpx.AsyncClient() as client: + await client.post( + f"{self.human_service_url}/tasks/{self.selected_task}/handle", + json=response, + ) + + # remove the input container + await self.query_one("#respond").remove() + + # remove the task from the list + new_tasks = [task for task in self.tasks if task.task_id != self.selected_task] + self.tasks = [*new_tasks] diff --git a/agentfile/app/components/service_list.py b/agentfile/app/components/service_list.py new file mode 100644 index 00000000..ac699d86 --- /dev/null +++ b/agentfile/app/components/service_list.py @@ -0,0 +1,49 @@ +import httpx +from typing import Any, List + +from textual.app import ComposeResult +from textual.containers import VerticalScroll +from textual.reactive import reactive +from textual.widgets import Button, Static + +from agentfile.app.components.types import ButtonType + + +class ServiceButton(Button): + type: ButtonType = ButtonType.SERVICE + + +class ServicesList(Static): + services: List[str] = reactive([]) + + def __init__(self, control_plane_url: str, **kwargs: Any): + self.control_plane_url = control_plane_url + super().__init__(**kwargs) + + def compose(self) -> ComposeResult: + with VerticalScroll(id="services-scroll"): + for service in self.services: + yield ServiceButton(service) + + async def on_mount(self) -> None: + self.set_interval(2, self.refresh_services) + + async def refresh_services(self) -> None: + async with httpx.AsyncClient() as client: + response = await client.get(f"{self.control_plane_url}/services") + services_dict = response.json() + + new_services = [] + for service_name in services_dict: + new_services.append(service_name) + + self.services = [*new_services] + + 
async def watch_services(self, new_services: List[str]) -> None: + try: + services_scroll = self.query_one("#services-scroll") + await services_scroll.remove_children() + for service in new_services: + await services_scroll.mount(ServiceButton(service)) + except Exception: + pass diff --git a/agentfile/app/components/task_list.py b/agentfile/app/components/task_list.py new file mode 100644 index 00000000..228cae22 --- /dev/null +++ b/agentfile/app/components/task_list.py @@ -0,0 +1,49 @@ +import httpx +from typing import Any, List + +from textual.app import ComposeResult +from textual.containers import VerticalScroll +from textual.reactive import reactive +from textual.widgets import Button, Static + +from agentfile.app.components.types import ButtonType + + +class TaskButton(Button): + type: ButtonType = ButtonType.TASK + + +class TasksList(Static): + tasks: List[str] = reactive([]) + + def __init__(self, control_plane_url: str, **kwargs: Any): + self.control_plane_url = control_plane_url + super().__init__(**kwargs) + + def compose(self) -> ComposeResult: + with VerticalScroll(id="tasks-scroll"): + for task in self.tasks: + yield TaskButton(task) + + async def on_mount(self) -> None: + self.set_interval(5, self.refresh_tasks) + + async def refresh_tasks(self) -> None: + async with httpx.AsyncClient() as client: + response = await client.get(f"{self.control_plane_url}/tasks") + tasks_dict = response.json() + + new_tasks = [] + for task_id in tasks_dict: + new_tasks.append(task_id) + + self.tasks = [*new_tasks] + + async def watch_tasks(self, new_tasks: List[str]) -> None: + try: + tasks_scroll = self.query_one("#tasks-scroll") + await tasks_scroll.remove_children() + for task in new_tasks: + await tasks_scroll.mount(TaskButton(task)) + except Exception: + pass diff --git a/agentfile/app/components/types.py b/agentfile/app/components/types.py new file mode 100644 index 00000000..1f1d8b2c --- /dev/null +++ b/agentfile/app/components/types.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class ButtonType(str, Enum): + SERVICE = "Service" + TASK = "Task" + HUMAN = "Human" diff --git a/agentfile/cli/__init__.py b/agentfile/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agentfile/cli/command_line.py b/agentfile/cli/command_line.py new file mode 100644 index 00000000..38921bd5 --- /dev/null +++ b/agentfile/cli/command_line.py @@ -0,0 +1,28 @@ +import argparse + +from agentfile.app.app import run as launch_monitor + + +def main() -> None: + parser = argparse.ArgumentParser(description="llama-agents CLI interface.") + + # Subparsers for the main commands + subparsers = parser.add_subparsers(title="commands", dest="command", required=True) + + # Subparser for the monitor command + monitor_parser = subparsers.add_parser("monitor", help="Monitor the agents.") + monitor_parser.add_argument( + "--control-plane-url", + default="http://127.0.0.1:8000", + help="The URL of the control plane. 
Defaults to http://127.0.0.1:8000", + ) + monitor_parser.set_defaults( + func=lambda args: launch_monitor(args.control_plane_url) + ) + + args = parser.parse_args() + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/agentfile/control_plane/__init__.py b/agentfile/control_plane/__init__.py index 5f6e3c6a..91c7e669 100644 --- a/agentfile/control_plane/__init__.py +++ b/agentfile/control_plane/__init__.py @@ -1,4 +1,4 @@ from agentfile.control_plane.base import BaseControlPlane -from agentfile.control_plane.fastapi import FastAPIControlPlane +from agentfile.control_plane.server import ControlPlaneServer -__all__ = ["BaseControlPlane", "FastAPIControlPlane"] +__all__ = ["BaseControlPlane", "ControlPlaneServer"] diff --git a/agentfile/control_plane/base.py b/agentfile/control_plane/base.py index 47f1dd75..a352690b 100644 --- a/agentfile/control_plane/base.py +++ b/agentfile/control_plane/base.py @@ -19,7 +19,7 @@ class BaseControlPlane(MessageQueuePublisherMixin, ABC): @abstractmethod - def as_consumer(self) -> BaseMessageQueueConsumer: + def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer: """ Get the consumer for the message queue. @@ -93,3 +93,10 @@ async def get_all_tasks(self) -> dict: :return: All tasks. """ ... + + @abstractmethod + async def launch_server(self) -> None: + """ + Launch the control plane server. + """ + ... diff --git a/agentfile/control_plane/fastapi.py b/agentfile/control_plane/server.py similarity index 72% rename from agentfile/control_plane/fastapi.py rename to agentfile/control_plane/server.py index 390ab6b7..b31d971a 100644 --- a/agentfile/control_plane/fastapi.py +++ b/agentfile/control_plane/server.py @@ -1,7 +1,7 @@ import uuid import uvicorn from fastapi import FastAPI -from typing import Any, Callable, Dict, List, Optional +from typing import Dict, List, Optional from llama_index.core import StorageContext, VectorStoreIndex from llama_index.core.objects import ObjectIndex, SimpleObjectNodeMapping @@ -11,6 +11,8 @@ from agentfile.control_plane.base import BaseControlPlane from agentfile.message_consumers.base import BaseMessageQueueConsumer +from agentfile.message_consumers.callable import CallableMessageConsumer +from agentfile.message_consumers.remote import RemoteMessageConsumer from agentfile.message_queues.base import BaseMessageQueue, PublishCallback from agentfile.messages.base import QueueMessage from agentfile.orchestrators.base import BaseOrchestrator @@ -29,22 +31,7 @@ logging.basicConfig(level=logging.INFO) -class ControlPlaneMessageConsumer(BaseMessageQueueConsumer): - message_handler: Dict[str, Callable] - message_type: str = "control_plane" - - async def _process_message(self, message: QueueMessage, **kwargs: Any) -> None: - action = message.action - if action not in self.message_handler: - raise ValueError(f"Action {action} not supported by control plane") - - if action == ActionTypes.NEW_TASK and message.data is not None: - await self.message_handler[action](TaskDefinition(**message.data)) - elif action == ActionTypes.COMPLETED_TASK and message.data is not None: - await self.message_handler[action](TaskResult(**message.data)) - - -class FastAPIControlPlane(BaseControlPlane): +class ControlPlaneServer(BaseControlPlane): def __init__( self, message_queue: BaseMessageQueue, @@ -56,6 +43,8 @@ def __init__( tasks_store_key: str = "tasks", step_interval: float = 0.1, services_retrieval_threshold: int = 5, + host: str = "127.0.0.1", + port: int = 8000, running: bool = True, ) -> None: self.orchestrator = 
orchestrator @@ -68,6 +57,8 @@ def __init__( ) self.step_interval = step_interval self.running = running + self.host = host + self.port = port self.state_store = state_store or SimpleKVStore() @@ -85,6 +76,12 @@ def __init__( self.app = FastAPI() self.app.add_api_route("/", self.home, methods=["GET"], tags=["Control Plane"]) + self.app.add_api_route( + "/process_message", + self.process_message, + methods=["POST"], + tags=["Control Plane"], + ) self.app.add_api_route( "/services/register", @@ -98,10 +95,25 @@ def __init__( methods=["POST"], tags=["Services"], ) + self.app.add_api_route( + "/services/{service_name}", + self.get_service, + methods=["GET"], + tags=["Services"], + ) + self.app.add_api_route( + "/services", + self.get_all_services, + methods=["GET"], + tags=["Services"], + ) self.app.add_api_route( "/tasks", self.create_task, methods=["POST"], tags=["Tasks"] ) + self.app.add_api_route( + "/tasks", self.get_all_tasks, methods=["GET"], tags=["Tasks"] + ) self.app.add_api_route( "/tasks/{task_id}", self.get_task_state, methods=["GET"], tags=["Tasks"] ) @@ -118,16 +130,39 @@ def publisher_id(self) -> str: def publish_callback(self) -> Optional[PublishCallback]: return self._publish_callback - def as_consumer(self) -> BaseMessageQueueConsumer: - return ControlPlaneMessageConsumer( - message_handler={ - ActionTypes.NEW_TASK: self.create_task, - ActionTypes.COMPLETED_TASK: self.handle_service_completion, - } + async def process_message(self, message: QueueMessage) -> None: + action = message.action + + if action == ActionTypes.NEW_TASK and message.data is not None: + await self.create_task(TaskDefinition(**message.data)) + elif action == ActionTypes.COMPLETED_TASK and message.data is not None: + await self.handle_service_completion(TaskResult(**message.data)) + else: + raise ValueError(f"Action {action} not supported by control plane") + + def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer: + if remote: + return RemoteMessageConsumer( + url=f"http://{self.host}:{self.port}/process_message", + message_type="control_plane", + ) + + return CallableMessageConsumer( + message_type="control_plane", + handler=self.process_message, ) - def launch(self) -> None: - uvicorn.run(self.app) + async def launch_server(self) -> None: + logger.info(f"Launching control plane server at {self.host}:{self.port}") + # uvicorn.run(self.app, host=self.host, port=self.port) + + class CustomServer(uvicorn.Server): + def install_signal_handlers(self) -> None: + pass + + cfg = uvicorn.Config(self.app, host=self.host, port=self.port) + server = CustomServer(cfg) + await server.serve() async def home(self) -> Dict[str, str]: return { @@ -167,6 +202,24 @@ async def deregister_service(self, service_name: str) -> None: self._total_services -= 1 # TODO: object index does not have delete yet + async def get_service(self, service_name: str) -> ServiceDefinition: + service_dict = await self.state_store.aget( + service_name, collection=self.services_store_key + ) + if service_dict is None: + raise ValueError(f"Service with name {service_name} not found") + + return ServiceDefinition.model_validate(service_dict) + + async def get_all_services(self) -> Dict[str, ServiceDefinition]: + service_dicts = await self.state_store.aget_all( + collection=self.services_store_key + ) + return { + service_name: ServiceDefinition.model_validate(service_dict) + for service_name, service_dict in service_dicts.items() + } + async def create_task(self, task_def: TaskDefinition) -> None: await self.state_store.aput( 
task_def.task_id, task_def.model_dump(), collection=self.tasks_store_key @@ -238,3 +291,15 @@ async def get_all_tasks(self) -> Dict[str, TaskDefinition]: task_id: TaskDefinition.model_validate(state_dict) for task_id, state_dict in state_dicts.items() } + + +if __name__ == "__main__": + from agentfile import SimpleMessageQueue, AgentOrchestrator + from llama_index.llms.openai import OpenAI + + control_plane = ControlPlaneServer( + SimpleMessageQueue(), AgentOrchestrator(llm=OpenAI()) + ) + import asyncio + + asyncio.run(control_plane.launch_server()) diff --git a/agentfile/launchers/__init__.py b/agentfile/launchers/__init__.py index afa55135..0f576c7a 100644 --- a/agentfile/launchers/__init__.py +++ b/agentfile/launchers/__init__.py @@ -1,3 +1,4 @@ from agentfile.launchers.local import LocalLauncher +from agentfile.launchers.server import ServerLauncher -__all__ = ["LocalLauncher"] +__all__ = ["LocalLauncher", "ServerLauncher"] diff --git a/agentfile/launchers/server.py b/agentfile/launchers/server.py new file mode 100644 index 00000000..3d2243f9 --- /dev/null +++ b/agentfile/launchers/server.py @@ -0,0 +1,101 @@ +import asyncio +import signal +import sys +import uuid +from typing import Any, Callable, Dict, List, Optional + +from agentfile.services.base import BaseService +from agentfile.control_plane.base import BaseControlPlane +from agentfile.message_consumers.base import BaseMessageQueueConsumer +from agentfile.message_queues.simple import SimpleMessageQueue +from agentfile.message_queues.base import PublishCallback +from agentfile.messages.base import QueueMessage +from agentfile.types import ActionTypes +from agentfile.message_publishers.publisher import MessageQueuePublisherMixin + + +class HumanMessageConsumer(BaseMessageQueueConsumer): + message_handler: Dict[str, Callable] + message_type: str = "human" + + async def _process_message(self, message: QueueMessage, **kwargs: Any) -> None: + action = message.action + if action not in self.message_handler: + raise ValueError(f"Action {action} not supported by the human consumer") + + if action == ActionTypes.COMPLETED_TASK: + await self.message_handler[action](message_data=message.data) + + +class ServerLauncher(MessageQueuePublisherMixin): + def __init__( + self, + services: List[BaseService], + control_plane: BaseControlPlane, + message_queue: SimpleMessageQueue, + publish_callback: Optional[PublishCallback] = None, + ) -> None: + self.services = services + self.control_plane = control_plane + self._message_queue = message_queue + self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}" + self._publish_callback = publish_callback + self.result: Optional[str] = None + + @property + def message_queue(self) -> SimpleMessageQueue: + return self._message_queue + + @property + def publisher_id(self) -> str: + return self._publisher_id + + @property + def publish_callback(self) -> Optional[PublishCallback]: + return self._publish_callback + + def get_shutdown_handler(self, tasks: List[asyncio.Task]) -> Callable: + def signal_handler(sig: Any, frame: Any) -> None: + print("\nShutting down.") + for task in tasks: + task.cancel() + sys.exit(0) + + return signal_handler + + def launch_servers(self) -> None: + return asyncio.run(self.alaunch_servers()) + + async def alaunch_servers(self) -> None: + # launch the message queue + queue_task = asyncio.create_task(self.message_queue.launch_server()) + + # wait for the message queue to be ready + await asyncio.sleep(1) + + # launch the control plane + control_plane_task =
asyncio.create_task(self.control_plane.launch_server()) + + # wait for the control plane to be ready + await asyncio.sleep(1) + + # register the control plane as a consumer + await self.message_queue.client.register_consumer( + self.control_plane.as_consumer(remote=True) + ) + + # register the services + control_plane_url = f"http://{self.control_plane.host}:{self.control_plane.port}" # type: ignore + service_tasks = [] + for service in self.services: + service_tasks.append(asyncio.create_task(service.launch_server())) + await service.register_to_message_queue() + await service.register_to_control_plane(control_plane_url) + + shutdown_handler = self.get_shutdown_handler( + [*service_tasks, queue_task, control_plane_task] + ) + loop = asyncio.get_event_loop() + while loop.is_running(): + await asyncio.sleep(0.1) + signal.signal(signal.SIGINT, shutdown_handler) diff --git a/agentfile/message_queues/base.py b/agentfile/message_queues/base.py index 0062651a..476eda73 100644 --- a/agentfile/message_queues/base.py +++ b/agentfile/message_queues/base.py @@ -64,7 +64,8 @@ async def publish( @abstractmethod async def register_consumer( - self, consumer: "BaseMessageQueueConsumer", **kwargs: Any + self, + consumer: "BaseMessageQueueConsumer", ) -> Any: """Register consumer to start consuming messages.""" @@ -92,6 +93,6 @@ async def launch_local(self) -> None: ... @abstractmethod - def launch_server(self) -> None: + async def launch_server(self) -> None: """Launch the service as a server.""" ... diff --git a/agentfile/message_queues/remote_client.py b/agentfile/message_queues/remote_client.py index c53b378d..b31686e0 100644 --- a/agentfile/message_queues/remote_client.py +++ b/agentfile/message_queues/remote_client.py @@ -70,7 +70,7 @@ async def processing_loop(self) -> None: async def launch_local(self) -> None: raise NotImplementedError("`launch_local()` is not implemented for this class.") - def launch_server(self) -> None: + async def launch_server(self) -> None: raise NotImplementedError( "`launch_server()` is not implemented for this class." 
) diff --git a/agentfile/message_queues/simple.py b/agentfile/message_queues/simple.py index 84ab8e42..4efe4f6d 100644 --- a/agentfile/message_queues/simple.py +++ b/agentfile/message_queues/simple.py @@ -35,7 +35,7 @@ class SimpleMessageQueue(BaseMessageQueue): ) queues: Dict[str, deque] = Field(default_factory=dict) running: bool = True - port: int = 8003 + port: int = 8001 host: str = "127.0.0.1" _app: FastAPI = PrivateAttr() @@ -45,7 +45,7 @@ def __init__( consumers: Dict[str, Dict[str, BaseMessageQueueConsumer]] = {}, queues: Dict[str, deque] = {}, host: str = "127.0.0.1", - port: int = 8003, + port: int = 8001, ): super().__init__(consumers=consumers, queues=queues, host=host, port=port) @@ -79,6 +79,13 @@ def __init__( tags=["QueueMessages"], ) + @property + def client(self) -> BaseMessageQueue: + from agentfile.message_queues.remote_client import RemoteClientMessageQueue + + base_url = f"http://{self.host}:{self.port}" + return RemoteClientMessageQueue(base_url=base_url) + def _select_consumer(self, message: QueueMessage) -> BaseMessageQueueConsumer: """Select a single consumer to publish a message to.""" message_type_str = message.type @@ -196,6 +203,14 @@ async def launch_local(self) -> None: logger.info("Launching message queue locally") asyncio.create_task(self.processing_loop()) - def launch_server(self) -> None: - logger.info("Launching message queue server") - uvicorn.run(self._app, host=self.host, port=self.port) + async def launch_server(self) -> None: + logger.info(f"Launching message queue server at {self.host}:{self.port}") + + # uvicorn.run(self._app, host=self.host, port=self.port) + class CustomServer(uvicorn.Server): + def install_signal_handlers(self) -> None: + pass + + cfg = uvicorn.Config(self._app, host=self.host, port=self.port) + server = CustomServer(cfg) + await server.serve() diff --git a/agentfile/orchestrators/agent.py b/agentfile/orchestrators/agent.py index 91b2a967..5707f44d 100644 --- a/agentfile/orchestrators/agent.py +++ b/agentfile/orchestrators/agent.py @@ -1,15 +1,16 @@ from typing import Any, Dict, List, Tuple -from llama_index.core.llms import LLM, ChatMessage +from llama_index.core.llms import LLM from llama_index.core.memory import ChatMemoryBuffer from llama_index.core.tools import BaseTool from agentfile.messages.base import QueueMessage from agentfile.orchestrators.base import BaseOrchestrator from agentfile.orchestrators.service_tool import ServiceTool -from agentfile.types import ActionTypes, TaskDefinition, TaskResult +from agentfile.types import ActionTypes, ChatMessage, TaskDefinition, TaskResult HISTORY_KEY = "chat_history" +RESULT_KEY = "result" DEFAULT_SUMMARIZE_TMPL = "{history}\n\nThe above represents the progress so far, please condense the messages into a single message." 
DEFAULT_FOLLOWUP_TMPL = ( "Pick the next action to take, or return a final response if my original " @@ -59,15 +60,20 @@ async def get_next_messages( # check if there was a tool call queue_messages = [] + result = None if len(response.sources) == 0 or response.sources[0].tool_name == "finalize": + # convert memory chat messages + llama_messages = memory.get_all() + history = [ChatMessage(**x.dict()) for x in llama_messages] + + result = TaskResult( + task_id=task_def.task_id, history=history, result=response.response + ) + queue_messages.append( QueueMessage( type="human", - data=TaskResult( - task_id=task_def.task_id, - history=memory.get_all(), - result=response.response, - ).model_dump(), + data=result.model_dump(), action=ActionTypes.COMPLETED_TASK, ) ) @@ -86,7 +92,10 @@ async def get_next_messages( ) ) - new_state = {HISTORY_KEY: [x.dict() for x in memory.get_all()]} + new_state = { + HISTORY_KEY: [x.dict() for x in memory.get_all()], + RESULT_KEY: result.model_dump() if result is not None else None, + } return queue_messages, new_state async def add_result_to_state( diff --git a/agentfile/orchestrators/pipeline.py b/agentfile/orchestrators/pipeline.py index cc39ab36..aeed766f 100644 --- a/agentfile/orchestrators/pipeline.py +++ b/agentfile/orchestrators/pipeline.py @@ -10,6 +10,11 @@ from agentfile.orchestrators.service_component import ServiceComponent from agentfile.types import ActionTypes, TaskDefinition, TaskResult +RUN_STATE_KEY = "run_state" +NEXT_SERVICE_KEYS = "next_service_keys" +LAST_MODULES_RUN = "last_modules_run" +RESULT_KEY = "result" + class PipelineOrchestrator(BaseOrchestrator): def __init__( @@ -22,10 +27,10 @@ async def get_next_messages( self, task_def: TaskDefinition, tools: List[BaseTool], state: Dict[str, Any] ) -> Tuple[List[QueueMessage], Dict[str, Any]]: # check if we need to init the state - if "run_state" not in state: + if RUN_STATE_KEY not in state: run_state = self.pipeline.get_run_state(input=task_def.input) else: - run_state = pickle.loads(state["run_state"]) + run_state = pickle.loads(state[RUN_STATE_KEY]) # run the next step in the pipeline, until we hit a service component next_module_keys = self.pipeline.get_next_module_keys(run_state) @@ -62,6 +67,7 @@ async def get_next_messages( # run the module if it is not a service component output_dict = await module.arun_component(**module_input) + # check if the output is a service component if "service_output" in output_dict: found_service_component = True next_service_keys.append(module_key) @@ -78,7 +84,7 @@ async def get_next_messages( ) continue - # process the output + # process the output if it is not a service component self.pipeline.process_component_output( output_dict, module_key, @@ -99,9 +105,10 @@ async def get_next_messages( break # did we find a service component? 
+ task_result = None if len(next_service_keys) == 0 and len(next_messages) == 0: # no service component found, return the final result - last_modules_run = state.get("last_modules_run", []) + last_modules_run = state.get(LAST_MODULES_RUN, []) result_dict = run_state.result_outputs[module_key or last_modules_run[-1]] if len(result_dict) == 1: @@ -109,23 +116,27 @@ async def get_next_messages( else: result = str(result_dict) + task_result = TaskResult( + task_id=task_def.task_id, + result=result, + history=[], + ) + next_messages.append( QueueMessage( type="human", action=ActionTypes.COMPLETED_TASK, - data=TaskResult( - task_id=task_def.task_id, - result=result, - history=[], - ).model_dump(), + data=task_result.model_dump(), ) ) - new_state = { - "run_state": pickle.dumps(run_state), - "next_service_keys": next_service_keys, - } - return next_messages, new_state + state[RUN_STATE_KEY] = pickle.dumps(run_state) + state[NEXT_SERVICE_KEYS] = next_service_keys + state[RESULT_KEY] = ( + task_result.model_dump() if task_result is not None else None + ) + + return next_messages, state async def add_result_to_state( self, @@ -134,8 +145,8 @@ async def add_result_to_state( ) -> Dict[str, Any]: """Add the result of processing a message to the state. Returns the new state.""" - run_state = pickle.loads(state["run_state"]) - next_service_keys = state["next_service_keys"] + run_state = pickle.loads(state[RUN_STATE_KEY]) + next_service_keys = state[NEXT_SERVICE_KEYS] # process the output of the service component(s) for module_key in next_service_keys: @@ -145,7 +156,6 @@ async def add_result_to_state( run_state, ) - return { - "run_state": pickle.dumps(run_state), - "last_modules_run": next_service_keys, - } + state[RUN_STATE_KEY] = pickle.dumps(run_state) + state[LAST_MODULES_RUN] = next_service_keys + return state diff --git a/agentfile/services/agent.py b/agentfile/services/agent.py index 69e85b75..d73b9d4e 100644 --- a/agentfile/services/agent.py +++ b/agentfile/services/agent.py @@ -7,7 +7,6 @@ from typing import AsyncGenerator, Dict, List, Literal, Optional from llama_index.core.agent import AgentRunner -from llama_index.core.llms import ChatMessage from agentfile.message_consumers.base import BaseMessageQueueConsumer from agentfile.message_consumers.callable import CallableMessageConsumer @@ -19,6 +18,7 @@ from agentfile.services.types import _ChatMessage from agentfile.types import ( ActionTypes, + ChatMessage, TaskResult, TaskDefinition, ServiceDefinition, @@ -114,6 +114,8 @@ def service_definition(self) -> ServiceDefinition: service_name=self.service_name, description=self.description, prompt=self.prompt or [], + host=self.host, + port=self.port, ) @property @@ -153,8 +155,9 @@ async def processing_loop(self) -> None: task_id, step_output=step_output ) - # get the latest history - history = self.agent.memory.get() + # convert memory chat messages + llama_messages = self.agent.memory.get() + history = [ChatMessage(**x.dict()) for x in llama_messages] # publish the completed task await self.publish( @@ -183,7 +186,7 @@ async def process_message(self, message: QueueMessage) -> None: def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer: if remote: - url = f"{self.host}:{self.port}/{self._app.url_path_for('process_message')}" + url = f"http://{self.host}:{self.port}{self._app.url_path_for('process_message')}" return RemoteMessageConsumer( url=url, message_type=self.service_name, @@ -208,11 +211,29 @@ async def lifespan(self, app: FastAPI) -> AsyncGenerator[None, None]: 
self.running = False async def home(self) -> Dict[str, str]: + tasks = self.agent.list_tasks() + + task_strings = [] + for task in tasks: + task_output = self.agent.get_task_output(task.task_id) + status = "COMPLETE" if task_output.is_last else "IN PROGRESS" + memory_str = "\n".join( + [f"{x.role}: {x.content}" for x in task.memory.get_all()] + ) + task_strings.append(f"Agent Task {task.task_id}: {status}\n{memory_str}") + + complete_task_string = "\n".join(task_strings) + return { "service_name": self.service_name, "description": self.description, "running": str(self.running), "step_interval": str(self.step_interval), + "num_tasks": str(len(tasks)), + "num_completed_tasks": str(len(self.agent.get_completed_tasks())), + "prompt": "\n".join([str(x) for x in self.prompt]) if self.prompt else "", + "type": "agent_service", + "tasks": complete_task_string, } async def create_task(self, task: TaskDefinition) -> Dict[str, str]: @@ -239,5 +260,14 @@ async def reset_agent(self) -> Dict[str, str]: return {"message": "Agent reset"} - def launch_server(self) -> None: - uvicorn.run(self._app, host=self.host, port=self.port) + async def launch_server(self) -> None: + logger.info(f"Launching {self.service_name} server at {self.host}:{self.port}") + # uvicorn.run(self._app, host=self.host, port=self.port) + + class CustomServer(uvicorn.Server): + def install_signal_handlers(self) -> None: + pass + + cfg = uvicorn.Config(self._app, host=self.host, port=self.port) + server = CustomServer(cfg) + await server.serve() diff --git a/agentfile/services/base.py b/agentfile/services/base.py index 43cd53b7..802c9f2e 100644 --- a/agentfile/services/base.py +++ b/agentfile/services/base.py @@ -1,3 +1,4 @@ +import httpx from abc import ABC, abstractmethod from pydantic import BaseModel from typing import Any @@ -50,6 +51,20 @@ async def launch_local(self) -> None: ... @abstractmethod - def launch_server(self) -> None: + async def launch_server(self) -> None: """Launch the service as a server.""" ... 
+ + async def register_to_control_plane(self, control_plane_url: str) -> None: + """Register the service to the control plane.""" + service_def = self.service_definition + async with httpx.AsyncClient() as client: + response = await client.post( + f"{control_plane_url}/services/register", + json=service_def.model_dump(), + ) + response.raise_for_status() + + async def register_to_message_queue(self) -> None: + """Register the service to the message queue.""" + await self.message_queue.register_consumer(self.as_consumer(remote=True)) diff --git a/agentfile/services/human.py b/agentfile/services/human.py index 291f480c..afdcd793 100644 --- a/agentfile/services/human.py +++ b/agentfile/services/human.py @@ -3,10 +3,11 @@ import uuid import uvicorn from asyncio import Lock -from contextlib import asynccontextmanager from fastapi import FastAPI -from typing import Any, AsyncGenerator, Dict, Optional -from pydantic import BaseModel, Field, PrivateAttr +from pydantic import PrivateAttr +from typing import Dict, List, Optional + +from llama_index.core.llms import MessageRole from agentfile.message_consumers.base import BaseMessageQueueConsumer from agentfile.message_consumers.callable import CallableMessageConsumer @@ -17,13 +18,14 @@ from agentfile.services.base import BaseService from agentfile.types import ( ActionTypes, + ChatMessage, + HumanResponse, TaskDefinition, TaskResult, ServiceDefinition, CONTROL_PLANE_NAME, - generate_id, ) -from llama_index.core.llms import ChatMessage, MessageRole + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -45,7 +47,7 @@ class HumanService(BaseService): host: Optional[str] = None port: Optional[int] = None - _outstanding_human_tasks: Dict[str, "HumanTask"] = PrivateAttr() + _outstanding_human_tasks: List[TaskDefinition] = PrivateAttr() _message_queue: BaseMessageQueue = PrivateAttr() _app: FastAPI = PrivateAttr() _publisher_id: str = PrivateAttr() @@ -72,17 +74,35 @@ def __init__( port=port, ) - self._outstanding_human_tasks = {} + self._outstanding_human_tasks = [] self._message_queue = message_queue self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}" self._publish_callback = publish_callback self._lock = asyncio.Lock() - self._app = FastAPI(lifespan=self.lifespan) + self._app = FastAPI() self._app.add_api_route("/", self.home, methods=["GET"], tags=["Human Service"]) + self._app.add_api_route( + "/process_message", + self.process_message, + methods=["POST"], + tags=["Human Service"], + ) self._app.add_api_route( - "/help", self.create_task, methods=["POST"], tags=["Help Requests"] + "/tasks", self.create_task, methods=["POST"], tags=["Tasks"] + ) + self._app.add_api_route( + "/tasks", self.get_tasks, methods=["GET"], tags=["Tasks"] + ) + self._app.add_api_route( + "/tasks/{task_id}", self.get_task, methods=["GET"], tags=["Tasks"] + ) + self._app.add_api_route( + "/tasks/{task_id}/handle", + self.handle_task, + methods=["POST"], + tags=["Tasks"], ) @property @@ -91,6 +111,8 @@ def service_definition(self) -> ServiceDefinition: service_name=self.service_name, description=self.description, prompt=[], + host=self.host, + port=self.port, ) @property @@ -109,21 +131,6 @@ def publish_callback(self) -> Optional[PublishCallback]: def lock(self) -> Lock: return self._lock - class HumanTask(BaseModel): - """Lightweight container object over TaskDefinitions. - - This is needed since orchestrators may send multiple `TaskDefinition` - with the same task_id. 
In such a case, this human service is expected - to address these multiple (sub)tasks for the overall task. In other words, - these sub tasks are all legitimate and should be processed. - """ - - id_: str = Field(default_factory=generate_id) - task_definition: TaskDefinition - - class Config: - arbitrary_types_allowed = True - async def processing_loop(self) -> None: while True: if not self.running: @@ -131,9 +138,12 @@ async def processing_loop(self) -> None: continue async with self.lock: - current_human_tasks = [*self._outstanding_human_tasks.values()] - for human_task in current_human_tasks: - task_def = human_task.task_definition + try: + task_def = self._outstanding_human_tasks.pop(0) + except IndexError: + await asyncio.sleep(self.step_interval) + continue + logger.info( f"Processing request for human help for task: {task_def.task_id}" ) @@ -167,24 +177,19 @@ async def processing_loop(self) -> None: ) ) - # clean up - async with self.lock: - del self._outstanding_human_tasks[human_task.id_] - await asyncio.sleep(self.step_interval) - async def process_message(self, message: QueueMessage, **kwargs: Any) -> None: + async def process_message(self, message: QueueMessage) -> None: if message.action == ActionTypes.NEW_TASK: task_def = TaskDefinition(**message.data or {}) - human_task = self.HumanTask(task_definition=task_def) async with self.lock: - self._outstanding_human_tasks.update({human_task.id_: human_task}) + self._outstanding_human_tasks.append(task_def) else: raise ValueError(f"Unhandled action: {message.action}") def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer: if remote: - url = f"{self.host}:{self.port}/{self._app.url_path_for('process_message')}" + url = f"http://{self.host}:{self.port}{self._app.url_path_for('process_message')}" return RemoteMessageConsumer( url=url, message_type=self.service_name, @@ -201,29 +206,76 @@ async def launch_local(self) -> None: # ---- Server based methods ---- - @asynccontextmanager - async def lifespan(self) -> AsyncGenerator[None, None]: - """Starts the processing loop when the fastapi app starts.""" - asyncio.create_task(self.processing_loop()) - yield - self.running = False - async def home(self) -> Dict[str, str]: return { "service_name": self.service_name, "description": self.description, "running": str(self.running), "step_interval": str(self.step_interval), + "num_tasks": str(len(self._outstanding_human_tasks)), + "tasks": "\n".join([str(task) for task in self._outstanding_human_tasks]), + "type": "human_service", } async def create_task(self, task: TaskDefinition) -> Dict[str, str]: - human_task = self.HumanTask(task_definition=task) async with self.lock: - self._outstanding_human_tasks.update({human_task.id_: human_task}) + self._outstanding_human_tasks.append(task) return {"task_id": task.task_id} - def launch_server(self) -> None: - uvicorn.run(self._app, host=self.host, port=self.port) + async def get_tasks(self) -> List[TaskDefinition]: + async with self.lock: + return [*self._outstanding_human_tasks] + + async def get_task(self, task_id: str) -> Optional[TaskDefinition]: + async with self.lock: + for task in self._outstanding_human_tasks: + if task.task_id == task_id: + return task + return None + + async def handle_task(self, task_id: str, result: HumanResponse) -> None: + async with self.lock: + for task_def in self._outstanding_human_tasks: + if task_def.task_id == task_id: + self._outstanding_human_tasks.remove(task_def) + break + + logger.info(f"Processing request for human help for task: 
{task_def.task_id}") + + # create history + history = [ + ChatMessage( + role=MessageRole.ASSISTANT, + content=HELP_REQUEST_TEMPLATE_STR.format(input_str=task_def.input), + ), + ChatMessage(role=MessageRole.USER, content=result.result), + ] + + # publish the completed task + await self.publish( + QueueMessage( + type=CONTROL_PLANE_NAME, + action=ActionTypes.COMPLETED_TASK, + data=TaskResult( + task_id=task_def.task_id, + history=history, + result=result.result, + ).model_dump(), + ) + ) + + async def launch_server(self) -> None: + logger.info( + f"Lanching server for {self.service_name} at {self.host}:{self.port}" + ) + + class CustomServer(uvicorn.Server): + def install_signal_handlers(self) -> None: + pass + + cfg = uvicorn.Config(self._app, host=self.host, port=self.port) + server = CustomServer(cfg) + await server.serve() HumanService.model_rebuild() diff --git a/agentfile/services/tool.py b/agentfile/services/tool.py index d1edae03..a903dcb2 100644 --- a/agentfile/services/tool.py +++ b/agentfile/services/tool.py @@ -11,7 +11,7 @@ from llama_index.core.agent.function_calling.step import ( get_function_by_name, ) -from llama_index.core.llms import ChatMessage, MessageRole +from llama_index.core.llms import MessageRole from llama_index.core.tools import BaseTool, AsyncBaseTool, adapt_to_async_tool from agentfile.message_consumers.base import BaseMessageQueueConsumer @@ -23,6 +23,7 @@ from agentfile.services.base import BaseService from agentfile.types import ( ActionTypes, + ChatMessage, ToolCall, ToolCallResult, ServiceDefinition, @@ -103,6 +104,8 @@ def service_definition(self) -> ServiceDefinition: service_name=self.service_name, description=self.description, prompt=[], + host=self.host, + port=self.port, ) @property @@ -185,7 +188,7 @@ async def process_message(self, message: QueueMessage) -> None: def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer: if remote: - url = f"{self.host}:{self.port}/{self._app.url_path_for('process_message')}" + url = f"http://{self.host}:{self.port}{self._app.url_path_for('process_message')}" return RemoteMessageConsumer( url=url, message_type=self.service_name, @@ -213,6 +216,12 @@ async def home(self) -> Dict[str, str]: "description": self.description, "running": str(self.running), "step_interval": str(self.step_interval), + "num_tools": str(len(self.tools)), + "num_outstanding_tool_calls": str(len(self._outstanding_tool_calls)), + "tool_calls": "\n".join( + [str(tool_call) for tool_call in self._outstanding_tool_calls.values()] + ), + "type": "tool_service", } async def create_tool_call(self, tool_call: ToolCall) -> Dict[str, str]: @@ -226,5 +235,14 @@ async def get_tool_by_name(self, name: str) -> Dict[str, Any]: raise ValueError(f"Tool with name {name} not found") return {"tool_metadata": name_to_tool[name].metadata} - def launch_server(self) -> None: - uvicorn.run(self._app, host=self.host, port=self.port) + async def launch_server(self) -> None: + logger.info(f"Launching tool service server at {self.host}:{self.port}") + # uvicorn.run(self._app, host=self.host, port=self.port) + + class CustomServer(uvicorn.Server): + def install_signal_handlers(self) -> None: + pass + + cfg = uvicorn.Config(self._app, host=self.host, port=self.port) + server = CustomServer(cfg) + await server.serve() diff --git a/agentfile/services/types.py b/agentfile/services/types.py index c38d7f0e..b17e0bc8 100644 --- a/agentfile/services/types.py +++ b/agentfile/services/types.py @@ -3,7 +3,8 @@ from llama_index.core.agent.types import TaskStep, 
TaskStepOutput, Task from llama_index.core.agent.runner.base import AgentState, TaskState -from llama_index.core.llms import ChatMessage + +from agentfile.types import ChatMessage # ------ FastAPI types ------ diff --git a/agentfile/types.py b/agentfile/types.py index 00a9e2ac..f9ca334e 100644 --- a/agentfile/types.py +++ b/agentfile/types.py @@ -1,9 +1,10 @@ import uuid from enum import Enum -from pydantic import BaseModel, Field, SkipValidation -from typing import Any, Dict, List, Optional +from pydantic import BaseModel, Field +from pydantic.v1 import BaseModel as V1BaseModel +from typing import Any, Dict, List, Optional, Union -from llama_index.core.llms import ChatMessage +from llama_index.core.llms import MessageRole def generate_id() -> str: @@ -13,6 +14,57 @@ def generate_id() -> str: CONTROL_PLANE_NAME = "control_plane" +class ChatMessage(BaseModel): + """Chat message. + + TODO: Temp copy of class from llama-index, to avoid pydantic v1/v2 issues. + """ + + role: MessageRole = MessageRole.USER + content: Optional[Any] = "" + additional_kwargs: dict = Field(default_factory=dict) + + def __str__(self) -> str: + return f"{self.role.value}: {self.content}" + + @classmethod + def from_str( + cls, + content: str, + role: Union[MessageRole, str] = MessageRole.USER, + **kwargs: Any, + ) -> "ChatMessage": + if isinstance(role, str): + role = MessageRole(role) + return cls(role=role, content=content, **kwargs) + + def _recursive_serialization(self, value: Any) -> Any: + if isinstance(value, (V1BaseModel, BaseModel)): + return value.dict() + if isinstance(value, dict): + return { + key: self._recursive_serialization(value) + for key, value in value.items() + } + if isinstance(value, list): + return [self._recursive_serialization(item) for item in value] + return value + + def dict(self, **kwargs: Any) -> dict: + # ensure all additional_kwargs are serializable + msg = super().dict(**kwargs) + + for key, value in msg.get("additional_kwargs", {}).items(): + value = self._recursive_serialization(value) + if not isinstance(value, (str, int, float, bool, dict, list, type(None))): + raise ValueError( + f"Failed to serialize additional_kwargs value: {value}" + ) + msg["additional_kwargs"][key] = value + + return msg + + class ActionTypes(str, Enum): NEW_TASK = "new_task" COMPLETED_TASK = "completed_task" @@ -30,7 +82,7 @@ class TaskDefinition(BaseModel): class TaskResult(BaseModel): task_id: str - history: SkipValidation[List[ChatMessage]] + history: List[ChatMessage] result: str @@ -48,7 +100,7 @@ class ToolCall(BaseModel): class ToolCallResult(BaseModel): id_: str - tool_message: SkipValidation[ChatMessage] + tool_message: ChatMessage result: str @@ -57,6 +109,12 @@ class ServiceDefinition(BaseModel): description: str = Field( description="A description of the service and it's purpose." ) - prompt: List[SkipValidation[ChatMessage]] = Field( + prompt: List[ChatMessage] = Field( default_factory=list, description="Specific instructions for the service." 
) + host: Optional[str] = None + port: Optional[int] = None + + +class HumanResponse(BaseModel): + result: str diff --git a/example_scripts/agentic_human_local_single.py b/example_scripts/agentic_human_local_single.py index 425f08e0..79338161 100644 --- a/example_scripts/agentic_human_local_single.py +++ b/example_scripts/agentic_human_local_single.py @@ -1,6 +1,6 @@ from agentfile.launchers.local import LocalLauncher from agentfile.services import HumanService, AgentService -from agentfile.control_plane.fastapi import FastAPIControlPlane +from agentfile.control_plane.server import ControlPlaneServer from agentfile.message_queues.simple import SimpleMessageQueue from agentfile.orchestrators.agent import AgentOrchestrator @@ -33,7 +33,7 @@ def get_the_secret_fact() -> str: message_queue=message_queue, description="Answers queries about math." ) -control_plane = FastAPIControlPlane( +control_plane = ControlPlaneServer( message_queue=message_queue, orchestrator=AgentOrchestrator(llm=OpenAI()), ) diff --git a/example_scripts/agentic_local_single.py b/example_scripts/agentic_local_single.py index 1755deeb..7788a825 100644 --- a/example_scripts/agentic_local_single.py +++ b/example_scripts/agentic_local_single.py @@ -1,7 +1,7 @@ from agentfile import ( AgentService, AgentOrchestrator, - FastAPIControlPlane, + ControlPlaneServer, LocalLauncher, SimpleMessageQueue, ) @@ -26,7 +26,7 @@ def get_the_secret_fact() -> str: # create our multi-agent framework components message_queue = SimpleMessageQueue() -control_plane = FastAPIControlPlane( +control_plane = ControlPlaneServer( message_queue=message_queue, orchestrator=AgentOrchestrator(llm=OpenAI()), ) diff --git a/example_scripts/agentic_server.py b/example_scripts/agentic_server.py new file mode 100644 index 00000000..51f44b07 --- /dev/null +++ b/example_scripts/agentic_server.py @@ -0,0 +1,64 @@ +from agentfile import ( + AgentService, + HumanService, + AgentOrchestrator, + ControlPlaneServer, + ServerLauncher, + SimpleMessageQueue, +) + +from llama_index.core.agent import FunctionCallingAgentWorker +from llama_index.core.tools import FunctionTool +from llama_index.llms.openai import OpenAI + + +# create an agent +def get_the_secret_fact() -> str: + """Returns the secret fact.""" + return "The secret fact is: A baby llama is called a 'Cria'." 
+ + +tool = FunctionTool.from_defaults(fn=get_the_secret_fact) + +worker1 = FunctionCallingAgentWorker.from_tools([tool], llm=OpenAI()) +worker2 = FunctionCallingAgentWorker.from_tools([], llm=OpenAI()) +agent1 = worker1.as_agent() +agent2 = worker2.as_agent() + +# create our multi-agent framework components +message_queue = SimpleMessageQueue() +queue_client = message_queue.client + +control_plane = ControlPlaneServer( + message_queue=queue_client, + orchestrator=AgentOrchestrator(llm=OpenAI()), +) +agent_server_1 = AgentService( + agent=agent1, + message_queue=queue_client, + description="Useful for getting the secret fact.", + service_name="secret_fact_agent", + host="127.0.0.1", + port=8002, +) +agent_server_2 = AgentService( + agent=agent2, + message_queue=queue_client, + description="Useful for getting random dumb facts.", + service_name="dumb_fact_agent", + host="127.0.0.1", + port=8003, +) +human_service = HumanService( + message_queue=queue_client, + description="Answers queries about math.", + host="127.0.0.1", + port=8004, +) + +# launch it +launcher = ServerLauncher( + [agent_server_1, agent_server_2, human_service], control_plane, message_queue +) + +launcher.launch_servers() diff --git a/example_scripts/agentic_toolservice_local_single.py b/example_scripts/agentic_toolservice_local_single.py index 343fbc1e..b816ba12 100644 --- a/example_scripts/agentic_toolservice_local_single.py +++ b/example_scripts/agentic_toolservice_local_single.py @@ -1,7 +1,7 @@ from agentfile.launchers.local import LocalLauncher from agentfile.services import AgentService, ToolService from agentfile.tools import MetaServiceTool -from agentfile.control_plane.fastapi import FastAPIControlPlane +from agentfile.control_plane.server import ControlPlaneServer from agentfile.message_queues.simple import SimpleMessageQueue from agentfile.orchestrators.agent import AgentOrchestrator @@ -28,7 +28,7 @@ def get_the_secret_fact() -> str: step_interval=0.5, ) -control_plane = FastAPIControlPlane( +control_plane = ControlPlaneServer( message_queue=message_queue, orchestrator=AgentOrchestrator(llm=OpenAI()), ) diff --git a/example_scripts/pipeline_human_local_single.py b/example_scripts/pipeline_human_local_single.py index 7d7dc1c9..fc5e56b6 100644 --- a/example_scripts/pipeline_human_local_single.py +++ b/example_scripts/pipeline_human_local_single.py @@ -1,7 +1,7 @@ from agentfile import ( AgentService, HumanService, - FastAPIControlPlane, + ControlPlaneServer, SimpleMessageQueue, PipelineOrchestrator, ServiceComponent, @@ -53,7 +53,7 @@ def get_the_secret_fact() -> str: pipeline_orchestrator = PipelineOrchestrator(pipeline) -control_plane = FastAPIControlPlane(message_queue, pipeline_orchestrator) +control_plane = ControlPlaneServer(message_queue, pipeline_orchestrator) # launch it launcher = LocalLauncher([agent_service, human_service], control_plane, message_queue) diff --git a/example_scripts/pipeline_local_single.py b/example_scripts/pipeline_local_single.py index 08da7544..8514b646 100644 --- a/example_scripts/pipeline_local_single.py +++ b/example_scripts/pipeline_local_single.py @@ -1,6 +1,6 @@ from agentfile import ( AgentService, - FastAPIControlPlane, + ControlPlaneServer, SimpleMessageQueue, PipelineOrchestrator, ServiceComponent, @@ -58,7 +58,7 @@ def get_the_secret_fact() -> str: pipeline_orchestrator = PipelineOrchestrator(pipeline) -control_plane = FastAPIControlPlane(message_queue, pipeline_orchestrator) +control_plane = ControlPlaneServer(message_queue, pipeline_orchestrator) # launch it launcher
= LocalLauncher([agent_server_1, agent_server_2], control_plane, message_queue) diff --git a/poetry.lock b/poetry.lock index 2d223505..6988abe6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -808,6 +808,26 @@ files = [ {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, ] +[[package]] +name = "linkify-it-py" +version = "2.0.3" +description = "Links recognition library with FULL unicode support." +optional = false +python-versions = ">=3.7" +files = [ + {file = "linkify-it-py-2.0.3.tar.gz", hash = "sha256:68cda27e162e9215c17d786649d1da0021a451bdc436ef9e0fa0ba5234b9b048"}, + {file = "linkify_it_py-2.0.3-py3-none-any.whl", hash = "sha256:6bcbc417b0ac14323382aef5c5192c0075bf8a9d6b41820a2b66371eac6b6d79"}, +] + +[package.dependencies] +uc-micro-py = "*" + +[package.extras] +benchmark = ["pytest", "pytest-benchmark"] +dev = ["black", "flake8", "isort", "pre-commit", "pyproject-flake8"] +doc = ["myst-parser", "sphinx", "sphinx-book-theme"] +test = ["coverage", "pytest", "pytest-cov"] + [[package]] name = "llama-index-agent-openai" version = "0.2.7" @@ -915,6 +935,8 @@ files = [ ] [package.dependencies] +linkify-it-py = {version = ">=1,<3", optional = true, markers = "extra == \"linkify\""} +mdit-py-plugins = {version = "*", optional = true, markers = "extra == \"plugins\""} mdurl = ">=0.1,<1.0" [package.extras] @@ -1015,6 +1037,25 @@ dev = ["marshmallow[tests]", "pre-commit (>=3.5,<4.0)", "tox"] docs = ["alabaster (==0.7.16)", "autodocsumm (==0.2.12)", "sphinx (==7.3.7)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"] tests = ["pytest", "pytz", "simplejson"] +[[package]] +name = "mdit-py-plugins" +version = "0.4.1" +description = "Collection of plugins for markdown-it-py" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mdit_py_plugins-0.4.1-py3-none-any.whl", hash = "sha256:1020dfe4e6bfc2c79fb49ae4e3f5b297f5ccd20f010187acc52af2921e27dc6a"}, + {file = "mdit_py_plugins-0.4.1.tar.gz", hash = "sha256:834b8ac23d1cd60cec703646ffd22ae97b7955a6d596eb1d304be1e251ae499c"}, +] + +[package.dependencies] +markdown-it-py = ">=1.0.0,<4.0.0" + +[package.extras] +code-style = ["pre-commit"] +rtd = ["myst-parser", "sphinx-book-theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + [[package]] name = "mdurl" version = "0.1.2" @@ -2114,6 +2155,25 @@ files = [ doc = ["reno", "sphinx"] test = ["pytest", "tornado (>=4.5)", "typeguard"] +[[package]] +name = "textual" +version = "0.70.0" +description = "Modern Text User Interface framework" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "textual-0.70.0-py3-none-any.whl", hash = "sha256:774bf45782193760ca273b915fd685cada37d0836237d61dc57d5bcdbe2c7ddb"}, + {file = "textual-0.70.0.tar.gz", hash = "sha256:9ca3f615b5cf442246325e40ef8255424c42b4241d3c62f9c0f96951bab82b1e"}, +] + +[package.dependencies] +markdown-it-py = {version = ">=2.1.0", extras = ["linkify", "plugins"]} +rich = ">=13.3.3" +typing-extensions = ">=4.4.0,<5.0.0" + +[package.extras] +syntax = ["tree-sitter (>=0.20.1,<0.21.0)", "tree-sitter-languages (==1.10.2)"] + [[package]] name = "tiktoken" version = "0.7.0" @@ -2251,6 +2311,20 @@ files = [ {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, ] +[[package]] +name = "uc-micro-py" +version = "1.0.3" +description = "Micro subset of unicode data files for linkify-it-py projects." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "uc-micro-py-1.0.3.tar.gz", hash = "sha256:d321b92cff673ec58027c04015fcaa8bb1e005478643ff4a500882eaab88c48a"}, + {file = "uc_micro_py-1.0.3-py3-none-any.whl", hash = "sha256:db1dffff340817673d7b466ec86114a9dc0e9d4d9b5ba229d9d60e5c12600cd5"}, +] + +[package.extras] +test = ["coverage", "pytest", "pytest-cov"] + [[package]] name = "ujson" version = "5.10.0" @@ -2778,4 +2852,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "e85b30c20285d5bcf0bbcae61542a41081761d20698774fef3a63a5af95c32f2" +content-hash = "9019caabc36284e23d753f092846a8f80f7262c33b14ff49cb3fba29b60e465d" diff --git a/pyproject.toml b/pyproject.toml index 9cf9205b..90e7af0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,11 +13,13 @@ readme = "README.md" python = ">=3.8.1,<4.0" fastapi = "^0.111.0" llama-index-core = "^0.10.47" -llama-index-agent-openai = "^0.2.5" -llama-index-embeddings-openai = "^0.1.10" pytest-asyncio = "^0.23.7" +textual = "^0.70.0" [tool.poetry.group.dev.dependencies] pytest = "^8.2.2" ruff = "^0.4.7" mypy = "^1.10.0" + +[tool.poetry.scripts] +llama-agents = 'agentfile.cli.command_line:main' diff --git a/tests/orchestrators/test_agent_orchestrator.py b/tests/orchestrators/test_agent_orchestrator.py index 31ab869e..f02201f3 100644 --- a/tests/orchestrators/test_agent_orchestrator.py +++ b/tests/orchestrators/test_agent_orchestrator.py @@ -1,9 +1,8 @@ -from llama_index.core.base.llms.types import ChatMessage -from llama_index.core.chat_engine.types import AgentChatResponse -from llama_index.core.tools.types import BaseTool, ToolOutput import pytest from typing import Any, List, Optional, Union +from llama_index.core.chat_engine.types import AgentChatResponse +from llama_index.core.tools.types import BaseTool, ToolOutput from llama_index.core.llms import ( CustomLLM, CompletionResponse, @@ -14,7 +13,7 @@ from agentfile.orchestrators.agent import AgentOrchestrator from agentfile.messages.base import QueueMessage from agentfile.orchestrators.service_tool import ServiceTool -from agentfile.types import ActionTypes, TaskDefinition, TaskResult +from agentfile.types import ActionTypes, ChatMessage, TaskDefinition, TaskResult TASK_DEF = TaskDefinition( input="Tell me a secret fact.", diff --git a/tests/services/test_human_service.py b/tests/services/test_human_service.py index 7d10e164..023f9260 100644 --- a/tests/services/test_human_service.py +++ b/tests/services/test_human_service.py @@ -64,7 +64,7 @@ async def test_create_task(mock_uuid: MagicMock) -> None: # assert assert result == {"task_id": task.task_id} - assert human_service._outstanding_human_tasks["mock_id"].task_definition == task + assert human_service._outstanding_human_tasks[0] == task @pytest.mark.asyncio()
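
A quick way to exercise the new HTTP surface end to end, without the monitor UI: a minimal client sketch, not part of this patch. It assumes the servers from example_scripts/agentic_server.py are running on their defaults, that TaskDefinition generates a task_id when one is not supplied, and that the task can complete without the human service. It submits a task through the control plane's new POST /tasks route and polls the GET /tasks/{task_id} route until the orchestrator has written a TaskResult into the task state, which is the same flow the Textual monitor (llama-agents monitor --control-plane-url http://127.0.0.1:8000) drives from its UI.

import asyncio

import httpx

from agentfile.types import TaskDefinition

CONTROL_PLANE_URL = "http://127.0.0.1:8000"


async def main() -> None:
    async with httpx.AsyncClient() as client:
        # submit a task via the new POST /tasks route on the control plane
        task = TaskDefinition(input="What is the secret fact?")
        await client.post(f"{CONTROL_PLANE_URL}/tasks", json=task.model_dump())

        # poll the GET /tasks/{task_id} route; the orchestrators store the
        # TaskResult dump in the task state under the "result" key
        while True:
            response = await client.get(f"{CONTROL_PLANE_URL}/tasks/{task.task_id}")
            state = response.json()["state"]
            if state.get("result"):
                # flatten the TaskResult object, as the monitor app does
                print(state["result"]["result"])
                return
            await asyncio.sleep(1)


if __name__ == "__main__":
    asyncio.run(main())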