From 86b756164e36e29f283497a8032d8590da788a0f Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Mon, 23 Dec 2024 09:33:03 +0100 Subject: [PATCH 1/4] remove tools package --- .pre-commit-config.yaml | 2 +- .../workflows/agent_workflow.py | 4 +- llama_deploy/__init__.py | 17 - llama_deploy/services/__init__.py | 8 +- llama_deploy/services/agent.py | 448 ----------------- llama_deploy/services/human.py | 449 ------------------ llama_deploy/tools/__init__.py | 14 - llama_deploy/tools/agent_service_tool.py | 4 - llama_deploy/tools/meta_service_tool.py | 274 ----------- llama_deploy/tools/service_as_tool.py | 287 ----------- llama_deploy/tools/service_component.py | 142 ------ llama_deploy/tools/service_tool.py | 45 -- llama_deploy/tools/utils.py | 6 - pyproject.toml | 3 +- tests/services/test_agent_service.py | 23 - tests/services/test_human_service.py | 311 ------------ tests/tools/test_agent_service_as_tool.py | 218 --------- tests/tools/test_human_service_as_tool.py | 145 ------ tests/tools/test_meta_service_tool.py | 223 --------- tests/tools/test_service_as_tool.py | 135 ------ 20 files changed, 6 insertions(+), 2752 deletions(-) delete mode 100644 llama_deploy/services/agent.py delete mode 100644 llama_deploy/services/human.py delete mode 100644 llama_deploy/tools/__init__.py delete mode 100644 llama_deploy/tools/agent_service_tool.py delete mode 100644 llama_deploy/tools/meta_service_tool.py delete mode 100644 llama_deploy/tools/service_as_tool.py delete mode 100644 llama_deploy/tools/service_component.py delete mode 100644 llama_deploy/tools/service_tool.py delete mode 100644 llama_deploy/tools/utils.py delete mode 100644 tests/services/test_agent_service.py delete mode 100644 tests/services/test_human_service.py delete mode 100644 tests/tools/test_agent_service_as_tool.py delete mode 100644 tests/tools/test_human_service_as_tool.py delete mode 100644 tests/tools/test_meta_service_tool.py delete mode 100644 tests/tools/test_service_as_tool.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3903d357..6e8fc5ea 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,7 +48,7 @@ repos: --ignore-missing-imports, --python-version=3.11, ] - exclude: ^(examples/|e2e_tests/|tests/tools/) + exclude: ^(examples/|e2e_tests/) - repo: https://github.com/adamchainz/blacken-docs rev: 1.16.0 diff --git a/examples/python_fullstack/workflows/agent_workflow.py b/examples/python_fullstack/workflows/agent_workflow.py index 4cfb6a94..54fe9b48 100644 --- a/examples/python_fullstack/workflows/agent_workflow.py +++ b/examples/python_fullstack/workflows/agent_workflow.py @@ -3,16 +3,16 @@ from llama_index.core.llms import ChatMessage from llama_index.core.memory import ChatMemoryBuffer +from llama_index.core.tools import FunctionTool from llama_index.core.workflow import ( Context, Event, - Workflow, StartEvent, StopEvent, + Workflow, step, ) from llama_index.llms.openai import OpenAI -from llama_index.core.tools import FunctionTool from .rag_workflow import RAGWorkflow diff --git a/llama_deploy/__init__.py b/llama_deploy/__init__.py index ecdbf1fa..788ff1de 100644 --- a/llama_deploy/__init__.py +++ b/llama_deploy/__init__.py @@ -12,20 +12,11 @@ from llama_deploy.messages import QueueMessage from llama_deploy.orchestrators import SimpleOrchestrator, SimpleOrchestratorConfig from llama_deploy.services import ( - AgentService, ComponentService, - HumanService, ToolService, WorkflowService, WorkflowServiceConfig, ) -from llama_deploy.tools import ( - AgentServiceTool, - 
MetaServiceTool, - ServiceAsTool, - ServiceComponent, - ServiceTool, -) root_logger = logging.getLogger("llama_deploy") @@ -44,8 +35,6 @@ "AsyncLlamaDeployClient", "Client", # services - "AgentService", - "HumanService", "ToolService", "ComponentService", "WorkflowService", @@ -66,10 +55,4 @@ # orchestrators "SimpleOrchestrator", "SimpleOrchestratorConfig", - # various utils - "AgentServiceTool", - "ServiceAsTool", - "ServiceComponent", - "ServiceTool", - "MetaServiceTool", ] diff --git a/llama_deploy/services/__init__.py b/llama_deploy/services/__init__.py index fc6744fa..45f64223 100644 --- a/llama_deploy/services/__init__.py +++ b/llama_deploy/services/__init__.py @@ -1,21 +1,17 @@ from llama_deploy.services.base import BaseService -from llama_deploy.services.agent import AgentService -from llama_deploy.services.human import HumanService -from llama_deploy.services.tool import ToolService from llama_deploy.services.component import ComponentService +from llama_deploy.services.tool import ToolService from llama_deploy.services.types import ( + _ChatMessage, _Task, _TaskSate, _TaskStep, _TaskStepOutput, - _ChatMessage, ) from llama_deploy.services.workflow import WorkflowService, WorkflowServiceConfig __all__ = [ "BaseService", - "AgentService", - "HumanService", "ToolService", "ComponentService", "WorkflowService", diff --git a/llama_deploy/services/agent.py b/llama_deploy/services/agent.py deleted file mode 100644 index 65627107..00000000 --- a/llama_deploy/services/agent.py +++ /dev/null @@ -1,448 +0,0 @@ -import asyncio -import uuid -from contextlib import asynccontextmanager -from logging import getLogger -from typing import AsyncGenerator, Dict, List, Literal, Optional, cast - -import uvicorn -from fastapi import FastAPI -from llama_index.core.agent import AgentRunner -from pydantic import PrivateAttr - -from llama_deploy.control_plane.server import CONTROL_PLANE_MESSAGE_TYPE -from llama_deploy.message_consumers.base import BaseMessageQueueConsumer -from llama_deploy.message_consumers.callable import CallableMessageConsumer -from llama_deploy.message_consumers.remote import RemoteMessageConsumer -from llama_deploy.message_publishers.publisher import PublishCallback -from llama_deploy.message_queues.base import BaseMessageQueue -from llama_deploy.messages.base import QueueMessage -from llama_deploy.services.base import BaseService -from llama_deploy.services.types import _ChatMessage -from llama_deploy.tools.utils import get_tool_name_from_service_name -from llama_deploy.types import ( - CONTROL_PLANE_NAME, - ActionTypes, - ChatMessage, - MessageRole, - ServiceDefinition, - TaskDefinition, - TaskResult, - ToolCall, - ToolCallBundle, - ToolCallResult, -) - -logger = getLogger(__name__) - - -class AgentService(BaseService): - """Agent Service. - - A service that runs an agent locally, processing incoming tasks step-wise in an endless loop. - - Messages are published to the message queue, and the agent processes them in a loop, - finally returning a message with the completed task. - - This AgentService can either be run in a local loop or as a Fast-API server. - - Exposes the following endpoints: - - GET `/`: Home endpoint. - - POST `/process_message`: Process a message. - - POST `/task`: Create a task. - - GET `/messages`: Get messages. - - POST `/toggle_agent_running`: Toggle the agent running state. - - GET `/is_worker_running`: Check if the agent is running. - - POST `/reset_agent`: Reset the agent. 
- - Since the agent can launch as a FastAPI server, you can visit `/docs` for full swagger documentation. - - Attributes: - service_name (str): - The name of the service. - agent (AgentRunner): - The agent to run. - description (str): - The description of the service. - prompt (Optional[List[ChatMessage]]): - The prompt messages, meant to be appended to the start of tasks (currently TODO). - running (bool): - Whether the agent is running. - step_interval (float): - The interval in seconds to poll for task completion. Defaults to 0.1s. - host (Optional[str]): - The host to launch a FastAPI server on. - port (Optional[int]): - The port to launch a FastAPI server on. - raise_exceptions (bool): - Whether to raise exceptions in the processing loop. - - Examples: - ```python - from llama_deploy import AgentService - from llama_index.core.agent import ReActAgent - - agent = ReActAgent.from_tools([...], llm=llm) - agent_service = AgentService( - agent, - message_queue, - service_name="my_agent_service", - description="My Agent Service", - host="127.0.0.1", - port=8003, - ) - - # launch as a server for remote access or documentation - await agent_service.launch_server() - ``` - """ - - service_name: str - agent: AgentRunner - description: str = "Local Agent Service." - prompt: Optional[List[ChatMessage]] = None - running: bool = True - step_interval: float = 0.1 - host: str - port: int - raise_exceptions: bool = False - - _message_queue: BaseMessageQueue = PrivateAttr() - _app: FastAPI = PrivateAttr() - _publisher_id: str = PrivateAttr() - _publish_callback: Optional[PublishCallback] = PrivateAttr() - _lock: asyncio.Lock = PrivateAttr() - _tasks_as_tool_calls: Dict[str, ToolCall] = PrivateAttr() - - def __init__( - self, - agent: AgentRunner, - message_queue: BaseMessageQueue, - host: str, - port: int, - running: bool = True, - description: str = "Agent Server", - service_name: str = "default_agent", - prompt: Optional[List[ChatMessage]] = None, - publish_callback: Optional[PublishCallback] = None, - step_interval: float = 0.1, - raise_exceptions: bool = False, - ) -> None: - super().__init__( - agent=agent, - running=running, - description=description, - service_name=service_name, - step_interval=step_interval, - prompt=prompt, - host=host, - port=port, - raise_exceptions=raise_exceptions, - ) - - self._lock = asyncio.Lock() - self._tasks_as_tool_calls = {} - self._message_queue = message_queue - self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}" - self._publish_callback = publish_callback - self._app = FastAPI(lifespan=self.lifespan) - - self._app.add_api_route("/", self.home, methods=["GET"], tags=["Agent State"]) - - self._app.add_api_route( - "/process_message", - self.process_message, - methods=["POST"], - tags=["Message Processing"], - ) - - self._app.add_api_route( - "/task", self.create_task, methods=["POST"], tags=["Tasks"] - ) - - self._app.add_api_route( - "/messages", self.get_messages, methods=["GET"], tags=["Agent State"] - ) - self._app.add_api_route( - "/toggle_agent_running", - self.toggle_agent_running, - methods=["POST"], - tags=["Agent State"], - ) - self._app.add_api_route( - "/is_worker_running", - self.is_worker_running, - methods=["GET"], - tags=["Agent State"], - ) - self._app.add_api_route( - "/reset_agent", self.reset_agent, methods=["POST"], tags=["Agent State"] - ) - - @property - def service_definition(self) -> ServiceDefinition: - """The service definition.""" - return ServiceDefinition( - service_name=self.service_name, - 
description=self.description, - prompt=self.prompt or [], - host=self.host, - port=self.port, - ) - - @property - def message_queue(self) -> BaseMessageQueue: - """The message queue.""" - return self._message_queue - - @property - def publisher_id(self) -> str: - """The publisher id.""" - return self._publisher_id - - @property - def publish_callback(self) -> Optional[PublishCallback]: - """The publish callback, if any.""" - return self._publish_callback - - @property - def lock(self) -> asyncio.Lock: - return self._lock - - @property - def tool_name(self) -> str: - """The name reserved when this service is used as a tool.""" - return get_tool_name_from_service_name(self.service_name) - - async def processing_loop(self) -> None: - """The processing loop for the agent.""" - logger.info("Processing initiated.") - while True: - try: - if not self.running: - await asyncio.sleep(self.step_interval) - continue - - current_tasks = self.agent.list_tasks() - current_task_ids = [task.task_id for task in current_tasks] - - completed_tasks = self.agent.get_completed_tasks() - completed_task_ids = [task.task_id for task in completed_tasks] - - for task_id in current_task_ids: - if task_id in completed_task_ids: - continue - - step_output = await self.agent.arun_step(task_id) - - if step_output.is_last: - # finalize the response - response = self.agent.finalize_response( - task_id, step_output=step_output - ) - - # convert memory chat messages - llama_messages = self.agent.memory.get() - history = [ChatMessage(**x.dict()) for x in llama_messages] - - # publish the completed task - async with self.lock: - try: - tool_call = self._tasks_as_tool_calls.pop(task_id) - except KeyError: - tool_call = None - - if tool_call: - await self.publish( - QueueMessage( - type=tool_call.source_id, - action=ActionTypes.COMPLETED_TOOL_CALL, - data=ToolCallResult( - id_=tool_call.id_, - tool_message=ChatMessage( - content=str(response.response), - role=MessageRole.TOOL, - additional_kwargs={ - "name": tool_call.tool_call_bundle.tool_name, - "tool_call_id": tool_call.id_, - }, - ), - result=response.response, - ).model_dump(), - ) - ) - else: - await self.publish( - QueueMessage( - type=CONTROL_PLANE_NAME, - action=ActionTypes.COMPLETED_TASK, - data=TaskResult( - task_id=task_id, - history=history, - result=response.response, - ).model_dump(), - ) - ) - except Exception as e: - logger.error(f"Error in {self.service_name} processing_loop: {e}") - if self.raise_exceptions: - # Kill everything - # TODO: is there a better way to do this? 
- import signal - - signal.raise_signal(signal.SIGINT) - else: - await self.message_queue.publish( - QueueMessage( - type=CONTROL_PLANE_MESSAGE_TYPE, - action=ActionTypes.COMPLETED_TASK, - data=TaskResult( - task_id=task_id, - history=[], - result=f"Error during processing: {e}", - ).model_dump(), - ), - topic=self.get_topic(CONTROL_PLANE_MESSAGE_TYPE), - ) - - continue - - await asyncio.sleep(self.step_interval) - - async def process_message(self, message: QueueMessage) -> None: - """Handling for when a message is received.""" - if message.action == ActionTypes.NEW_TASK: - task_def = TaskDefinition(**message.data or {}) - self.agent.create_task(task_def.input, task_id=task_def.task_id) - logger.info(f"Created new task: {task_def.task_id}") - elif message.action == ActionTypes.NEW_TOOL_CALL: - task_def = TaskDefinition(**message.data or {}) - async with self.lock: - tool_call_bundle = ToolCallBundle( - tool_name=self.tool_name, - tool_args=[], - tool_kwargs={"input": task_def.input}, - ) - task_as_tool_call = ToolCall( - id_=task_def.task_id, - source_id=message.publisher_id, - tool_call_bundle=tool_call_bundle, - ) - self._tasks_as_tool_calls[task_def.task_id] = task_as_tool_call - self.agent.create_task(task_def.input, task_id=task_def.task_id) - logger.info(f"Created new tool call as task: {task_def.task_id}") - else: - raise ValueError(f"Unhandled action: {message.action}") - - def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer: - """Get the consumer for the message queue. - - Args: - remote (bool): - Whether to get a remote consumer or local. - If remote, calls the `process_message` endpoint. - """ - if remote: - url = ( - f"http://{self.host}:{self.port}{self._app.url_path_for('process_message')}" - if self.port - else f"http://{self.host}{self._app.url_path_for('process_message')}" - ) - return RemoteMessageConsumer( - id_=self.publisher_id, - url=url, - message_type=self.service_name, - ) - - return CallableMessageConsumer( - id_=self.publisher_id, - message_type=self.service_name, - handler=self.process_message, - ) - - async def launch_local(self) -> asyncio.Task: - """Launch the agent locally.""" - logger.info(f"{self.service_name} launch_local") - return asyncio.create_task(self.processing_loop()) - - # ---- Server based methods ---- - - @asynccontextmanager - async def lifespan(self, app: FastAPI) -> AsyncGenerator[None, None]: - """Starts the processing loop when the fastapi app starts.""" - asyncio.create_task(self.processing_loop()) - yield - self.running = False - - async def home(self) -> Dict[str, str]: - """Home endpoint. 
Gets general information about the agent service.""" - tasks = self.agent.list_tasks() - - task_strings = [] - for task in tasks: - task_output = self.agent.get_task_output(task.task_id) - status = "COMPLETE" if task_output.is_last else "IN PROGRESS" - memory_str = "\n".join( - [f"{x.role}: {x.content}" for x in task.memory.get_all()] - ) - task_strings.append(f"Agent Task {task.task_id}: {status}\n{memory_str}") - - complete_task_string = "\n".join(task_strings) - - return { - "service_name": self.service_name, - "description": self.description, - "running": str(self.running), - "step_interval": str(self.step_interval), - "num_tasks": str(len(tasks)), - "num_completed_tasks": str(len(self.agent.get_completed_tasks())), - "prompt": "\n".join([str(x) for x in self.prompt]) if self.prompt else "", - "type": "agent_service", - "tasks": complete_task_string, - } - - async def create_task(self, task_definition: TaskDefinition) -> Dict[str, str]: - """Create a task.""" - task = self.agent.create_task( - task_definition.input, task_id=task_definition.task_id - ) - return {"task_id": task.task_id} - - async def get_messages(self) -> List[_ChatMessage]: - """Get messages from the agent.""" - messages = self.agent.chat_history - - return [ - _ChatMessage.from_chat_message(cast(ChatMessage, message)) - for message in messages - ] - - async def toggle_agent_running( - self, state: Literal["running", "stopped"] - ) -> Dict[str, bool]: - """Toggle the agent running state.""" - self.running = state == "running" - - return {"running": self.running} - - async def is_worker_running(self) -> Dict[str, bool]: - """Check if the agent is running.""" - return {"running": self.running} - - async def reset_agent(self) -> Dict[str, str]: - """Reset the agent.""" - self.agent.reset() - - return {"message": "Agent reset"} - - async def launch_server(self) -> None: - """Launch the agent as a FastAPI server.""" - logger.info(f"Launching {self.service_name} server at {self.host}:{self.port}") - # uvicorn.run(self._app, host=self.host, port=self.port) - - class CustomServer(uvicorn.Server): - def install_signal_handlers(self) -> None: - pass - - cfg = uvicorn.Config(self._app, host=self.host, port=self.port) - server = CustomServer(cfg) - await server.serve() diff --git a/llama_deploy/services/human.py b/llama_deploy/services/human.py deleted file mode 100644 index 50e32fb3..00000000 --- a/llama_deploy/services/human.py +++ /dev/null @@ -1,449 +0,0 @@ -import asyncio -import uuid -from asyncio import Lock -from asyncio.exceptions import CancelledError -from contextlib import asynccontextmanager -from logging import getLogger -from typing import ( - Any, - AsyncGenerator, - Awaitable, - Dict, - List, - Optional, - Protocol, - runtime_checkable, -) - -import uvicorn -from fastapi import FastAPI -from llama_index.core.llms import MessageRole -from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator - -from llama_deploy.message_consumers.base import BaseMessageQueueConsumer -from llama_deploy.message_consumers.callable import CallableMessageConsumer -from llama_deploy.message_consumers.remote import RemoteMessageConsumer -from llama_deploy.message_publishers.publisher import PublishCallback -from llama_deploy.message_queues.base import BaseMessageQueue -from llama_deploy.messages.base import QueueMessage -from llama_deploy.services.base import BaseService -from llama_deploy.tools.utils import get_tool_name_from_service_name -from llama_deploy.types import ( - CONTROL_PLANE_NAME, - ActionTypes, - ChatMessage, 
- HumanResponse, - ServiceDefinition, - TaskDefinition, - TaskResult, - ToolCall, - ToolCallBundle, - ToolCallResult, -) -from llama_deploy.utils import get_prompt_params - -logger = getLogger(__name__) - - -HELP_REQUEST_TEMPLATE_STR = ( - "Your assistance is needed. Please respond to the request " - "provided below:\n===\n\n" - "{input_str}\n\n===\n" -) - - -@runtime_checkable -class HumanInputFn(Protocol): - """Protocol for getting human input.""" - - def __call__(self, prompt: str, task_id: str, **kwargs: Any) -> Awaitable[str]: ... - - -async def default_human_input_fn(prompt: str, task_id: str, **kwargs: Any) -> str: - del task_id - return input(prompt) - - -class HumanService(BaseService): - """A human service for providing human-in-the-loop assistance. - - When launched locally, it will prompt the user for input, which is blocking! - - When launched as a server, it will provide an API for creating and handling tasks. - - Exposes the following endpoints: - - GET `/`: Get the service information. - - POST `/process_message`: Process a message. - - POST `/tasks`: Create a task. - - GET `/tasks`: Get all tasks. - - GET `/tasks/{task_id}`: Get a task. - - POST `/tasks/{task_id}/handle`: Handle a task. - - Attributes: - service_name (str): The name of the service. - description (str): The description of the service. - running (bool): Whether the service is running. - step_interval (float): The interval in seconds to poll for tool call results. Defaults to 0.1s. - host (Optional[str]): The host of the service. - port (Optional[int]): The port of the service. - - - """ - - model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) - service_name: str - description: str = "Local Human Service." - running: bool = True - step_interval: float = 0.1 - fn_input: HumanInputFn = default_human_input_fn - human_input_prompt: str = ( - HELP_REQUEST_TEMPLATE_STR # TODO: use PromptMixin, PromptTemplate - ) - host: str - port: int - - _outstanding_human_tasks: List["HumanTask"] = PrivateAttr() - _message_queue: BaseMessageQueue = PrivateAttr() - _app: FastAPI = PrivateAttr() - _publisher_id: str = PrivateAttr() - _publish_callback: Optional[PublishCallback] = PrivateAttr() - _lock: Lock = PrivateAttr() - _tasks_as_tool_calls: Dict[str, ToolCall] = PrivateAttr() - - def __init__( - self, - message_queue: BaseMessageQueue, - running: bool = True, - description: str = "Local Human Service", - service_name: str = "default_human_service", - publish_callback: Optional[PublishCallback] = None, - step_interval: float = 0.1, - fn_input: HumanInputFn = default_human_input_fn, - human_input_prompt: Optional[str] = None, - host: Optional[str] = None, - port: Optional[int] = None, - ) -> None: - human_input_prompt = human_input_prompt or HELP_REQUEST_TEMPLATE_STR - super().__init__( - running=running, - description=description, - service_name=service_name, - step_interval=step_interval, - fn_input=fn_input, - human_input_prompt=human_input_prompt, - host=host, - port=port, - ) - - self._outstanding_human_tasks = [] - self._message_queue = message_queue - self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}" - self._publish_callback = publish_callback - self._lock = asyncio.Lock() - self._tasks_as_tool_calls = {} - self._app = FastAPI(lifespan=self.lifespan) - - self._app.add_api_route("/", self.home, methods=["GET"], tags=["Human Service"]) - self._app.add_api_route( - "/process_message", - self.process_message, - methods=["POST"], - tags=["Human Service"], - ) - - 
self._app.add_api_route( - "/tasks", self.create_task, methods=["POST"], tags=["Tasks"] - ) - self._app.add_api_route( - "/tasks", self.get_tasks, methods=["GET"], tags=["Tasks"] - ) - self._app.add_api_route( - "/tasks/{task_id}", self.get_task, methods=["GET"], tags=["Tasks"] - ) - self._app.add_api_route( - "/tasks/{task_id}/handle", - self.handle_task, - methods=["POST"], - tags=["Tasks"], - ) - - @property - def service_definition(self) -> ServiceDefinition: - """Get the service definition.""" - return ServiceDefinition( - service_name=self.service_name, - description=self.description, - prompt=[], - host=self.host, - port=self.port, - ) - - @property - def message_queue(self) -> BaseMessageQueue: - """The message queue.""" - return self._message_queue - - @property - def publisher_id(self) -> str: - """The publisher ID.""" - return self._publisher_id - - @property - def publish_callback(self) -> Optional[PublishCallback]: - """The publish callback, if any.""" - return self._publish_callback - - @property - def lock(self) -> Lock: - return self._lock - - @property - def tool_name(self) -> str: - """The name reserved when this service is used as a tool.""" - return get_tool_name_from_service_name(self.service_name) - - async def processing_loop(self) -> None: - """The processing loop for the service.""" - try: - await self._processing_loop() - except CancelledError: - logger.info("Processing cancelled.") - return - - async def _processing_loop(self) -> None: - logger.info("Processing initiated.") - while True: - if not self.running: - await asyncio.sleep(self.step_interval) - continue - - async with self.lock: - try: - human_task = self._outstanding_human_tasks.pop(0) - task_def = human_task.task_def - tool_call = human_task.tool_call - except IndexError: - await asyncio.sleep(self.step_interval) - continue - - logger.info( - f"Processing request for human help for task: {task_def.task_id}" - ) - - # process req - prompt = ( - self.human_input_prompt.format(input_str=task_def.input) - if self.human_input_prompt - else task_def.input - ) - result = await self.fn_input(prompt=prompt, task_id=task_def.task_id) - - # create history - history = [ - ChatMessage( - role=MessageRole.ASSISTANT, - content=HELP_REQUEST_TEMPLATE_STR.format( - input_str=task_def.input - ), - ), - ChatMessage(role=MessageRole.USER, content=result), - ] - - if tool_call: - await self.publish( - QueueMessage( - type=tool_call.source_id, - action=ActionTypes.COMPLETED_TOOL_CALL, - data=ToolCallResult( - id_=tool_call.id_, - tool_message=ChatMessage( - content=result, - role=MessageRole.TOOL, - additional_kwargs={ - "name": tool_call.tool_call_bundle.tool_name, - "tool_call_id": tool_call.id_, - }, - ), - result=result, - ).model_dump(), - ) - ) - else: - # publish the completed task - await self.publish( - QueueMessage( - type=CONTROL_PLANE_NAME, - action=ActionTypes.COMPLETED_TASK, - data=TaskResult( - task_id=task_def.task_id, - history=history, - result=result, - ).model_dump(), - ) - ) - - await asyncio.sleep(self.step_interval) - - class HumanTask(BaseModel): - """Container for Tasks to be completed by HumanService.""" - - task_def: TaskDefinition - tool_call: Optional[ToolCall] = None - - async def process_message(self, message: QueueMessage) -> None: - """Process a message received from the message queue.""" - if message.action == ActionTypes.NEW_TASK: - task_def = TaskDefinition(**message.data or {}) - human_task = self.HumanTask(task_def=task_def) - logger.info(f"Created new task: {task_def.task_id}") - elif 
message.action == ActionTypes.NEW_TOOL_CALL: - task_def = TaskDefinition(**message.data or {}) - tool_call_bundle = ToolCallBundle( - tool_name=self.tool_name, - tool_args=[], - tool_kwargs={"input": task_def.input}, - ) - task_as_tool_call = ToolCall( - id_=task_def.task_id, - source_id=message.publisher_id, - tool_call_bundle=tool_call_bundle, - ) - human_task = self.HumanTask(task_def=task_def, tool_call=task_as_tool_call) - logger.info(f"Created new tool call as task: {task_def.task_id}") - else: - raise ValueError(f"Unhandled action: {message.action}") - async with self.lock: - self._outstanding_human_tasks.append(human_task) - - def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer: - """Get the consumer for the service. - - Args: - remote (bool): - Whether the consumer is remote. Defaults to False. - If True, the consumer will be a RemoteMessageConsumer that uses the `process_message` endpoint. - """ - if remote: - url = ( - f"http://{self.host}:{self.port}{self._app.url_path_for('process_message')}" - if self.port - else f"http://{self.host}{self._app.url_path_for('process_message')}" - ) - return RemoteMessageConsumer( - id_=self.publisher_id, - url=url, - message_type=self.service_name, - ) - - return CallableMessageConsumer( - id_=self.publisher_id, - message_type=self.service_name, - handler=self.process_message, - ) - - async def launch_local(self) -> asyncio.Task: - """Launch the service in-process.""" - logger.info(f"{self.service_name} launch_local") - return asyncio.create_task(self.processing_loop()) - - # ---- Server based methods ---- - - @asynccontextmanager - async def lifespan(self, app: FastAPI) -> AsyncGenerator[None, None]: - """Starts the processing loop when the fastapi app starts.""" - asyncio.create_task(self.processing_loop()) - yield - self.running = False - - async def home(self) -> Dict[str, str]: - """Get general service information.""" - return { - "service_name": self.service_name, - "description": self.description, - "running": str(self.running), - "step_interval": str(self.step_interval), - "num_tasks": str(len(self._outstanding_human_tasks)), - "tasks": "\n".join([str(task) for task in self._outstanding_human_tasks]), - "type": "human_service", - } - - async def create_task(self, task: TaskDefinition) -> Dict[str, str]: - """Create a task for the human service.""" - async with self.lock: - human_task = self.HumanTask(task_def=task) - self._outstanding_human_tasks.append(human_task) - return {"task_id": task.task_id} - - async def get_tasks(self) -> List[TaskDefinition]: - """Get all outstanding tasks.""" - async with self.lock: - return [ht.task_def for ht in self._outstanding_human_tasks] - - async def get_task(self, task_id: str) -> Optional[TaskDefinition]: - """Get a specific task by ID.""" - async with self.lock: - for human_task in self._outstanding_human_tasks: - if human_task.task_def.task_id == task_id: - return human_task.task_def - return None - - async def handle_task(self, task_id: str, result: HumanResponse) -> None: - """Handle a task by providing a result.""" - async with self.lock: - for human_task in self._outstanding_human_tasks: - task_def = human_task.task_def - if task_def.task_id == task_id: - self._outstanding_human_tasks.remove(human_task) - break - - logger.info(f"Processing request for human help for task: {task_def.task_id}") - - # create history - history = [ - ChatMessage( - role=MessageRole.ASSISTANT, - content=HELP_REQUEST_TEMPLATE_STR.format(input_str=task_def.input), - ), - 
ChatMessage(role=MessageRole.USER, content=result.result), - ] - - # publish the completed task - await self.publish( - QueueMessage( - type=CONTROL_PLANE_NAME, - action=ActionTypes.COMPLETED_TASK, - data=TaskResult( - task_id=task_def.task_id, - history=history, - result=result.result, - ).model_dump(), - ) - ) - - async def launch_server(self) -> None: - """Launch the service as a FastAPI server.""" - logger.info( - f"Lanching server for {self.service_name} at {self.host}:{self.port}" - ) - - class CustomServer(uvicorn.Server): - def install_signal_handlers(self) -> None: - pass - - cfg = uvicorn.Config(self._app, host=self.host, port=self.port) - server = CustomServer(cfg) - await server.serve() - - @field_validator("human_input_prompt") - @classmethod - def validate_human_input_prompt(cls, v: str) -> str: - """Check if `input_str` is a prompt key.""" - prompt_params = get_prompt_params(v) - if "input_str" not in prompt_params: - raise ValueError( - "`input_str` should be the only param in the prompt template." - ) - return v - - -HumanService.model_rebuild() diff --git a/llama_deploy/tools/__init__.py b/llama_deploy/tools/__init__.py deleted file mode 100644 index 5233922d..00000000 --- a/llama_deploy/tools/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from llama_deploy.tools.agent_service_tool import AgentServiceTool -from llama_deploy.tools.meta_service_tool import MetaServiceTool -from llama_deploy.tools.service_as_tool import ServiceAsTool -from llama_deploy.tools.service_tool import ServiceTool -from llama_deploy.tools.service_component import ServiceComponent - - -__all__ = [ - "AgentServiceTool", - "MetaServiceTool", - "ServiceAsTool", - "ServiceTool", - "ServiceComponent", -] diff --git a/llama_deploy/tools/agent_service_tool.py b/llama_deploy/tools/agent_service_tool.py deleted file mode 100644 index 06408da4..00000000 --- a/llama_deploy/tools/agent_service_tool.py +++ /dev/null @@ -1,4 +0,0 @@ -from llama_deploy.tools.service_as_tool import ServiceAsTool - -# NOTE: for backwards compatibility -AgentServiceTool = ServiceAsTool diff --git a/llama_deploy/tools/meta_service_tool.py b/llama_deploy/tools/meta_service_tool.py deleted file mode 100644 index b4a6ddf1..00000000 --- a/llama_deploy/tools/meta_service_tool.py +++ /dev/null @@ -1,274 +0,0 @@ -import asyncio -import uuid -from logging import getLogger -from typing import Any, Dict, Optional - -from llama_index.core.tools import AsyncBaseTool, ToolMetadata, ToolOutput -from pydantic import BaseModel, Field, PrivateAttr - -from llama_deploy.message_consumers.base import BaseMessageQueueConsumer -from llama_deploy.message_consumers.callable import CallableMessageConsumer -from llama_deploy.message_publishers.publisher import ( - MessageQueuePublisherMixin, - PublishCallback, -) -from llama_deploy.message_queues.base import BaseMessageQueue -from llama_deploy.messages.base import QueueMessage -from llama_deploy.services.tool import ToolService -from llama_deploy.types import ( - ActionTypes, - ToolCall, - ToolCallBundle, - ToolCallResult, -) - -logger = getLogger(__name__) - - -class MetaServiceTool(MessageQueuePublisherMixin, AsyncBaseTool, BaseModel): - """A tool that uses a service to perform a task. - - When a tool call is made, this tool forwards the call to a service for execution. - This enables async/non-blocking tool calls. - - Attributes: - tool_call_results (Dict[str, ToolCallResult]): - A dictionary of tool call results. - timeout (float): - The timeout interval in seconds. 
- tool_service_name (str): - The name of the tool service. - step_interval (float): - The interval in seconds to poll for tool call results. - raise_timeout (bool): - Whether to raise a TimeoutError when the timeout is reached. - registered (bool): - Whether the tool is registered to the message queue. - - Examples: - ```python - from llama_deploy import SimpleMessageQueue - from llama_deploy.tools import MetaServiceTool - from llama_index.core.tools import ToolMetadata - - message_queue = SimpleMessageQueue() - tool_metadata = ToolMetadata(name="my_tool") - tool = MetaServiceTool( - tool_metadata=tool_metadata, - message_queue=message_queue, - tool_service_name="my_tool_service", - ) - result = await tool.acall("arg1", kwarg1="value1") - print(result) - """ - - tool_call_results: Dict[str, ToolCallResult] = Field(default_factory=dict) - timeout: float = Field(default=10.0, description="timeout interval in seconds.") - tool_service_name: str = Field(default_factory=str) - step_interval: float = 0.1 - raise_timeout: bool = False - registered: bool = False - - _message_queue: BaseMessageQueue = PrivateAttr() - _publisher_id: str = PrivateAttr() - _publish_callback: Optional[PublishCallback] = PrivateAttr() - _lock: asyncio.Lock = PrivateAttr() - _metadata: ToolMetadata = PrivateAttr() - - def __init__( - self, - tool_metadata: ToolMetadata, - message_queue: BaseMessageQueue, - tool_service_name: str, - publish_callback: Optional[PublishCallback] = None, - tool_call_results: Dict[str, ToolCallResult] = {}, - timeout: float = 10.0, - step_interval: float = 0.1, - raise_timeout: bool = False, - ) -> None: - super().__init__( - tool_call_results=tool_call_results, - timeout=timeout, - step_interval=step_interval, - tool_service_name=tool_service_name, - raise_timeout=raise_timeout, - ) - self._message_queue = message_queue - self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}" - self._publish_callback = publish_callback - self._metadata = tool_metadata - self._lock = asyncio.Lock() - - @classmethod - async def from_tool_service( - cls, - name: str, - message_queue: BaseMessageQueue, - tool_service: Optional[ToolService] = None, - tool_service_url: Optional[str] = None, - tool_service_api_key: Optional[str] = None, - tool_service_name: Optional[str] = None, - publish_callback: Optional[PublishCallback] = None, - timeout: float = 10.0, - step_interval: float = 0.1, - raise_timeout: bool = False, - ) -> "MetaServiceTool": - if tool_service is not None: - res = await tool_service.get_tool_by_name(name) - try: - tool_metadata = res["tool_metadata"] - except KeyError: - raise ValueError("tool_metadata not found.") - return cls( - tool_metadata=tool_metadata, - message_queue=message_queue, - tool_service_name=tool_service.service_name, - publish_callback=publish_callback, - timeout=timeout, - step_interval=step_interval, - raise_timeout=raise_timeout, - ) - # TODO by requests - # make a http request, try to parse into BaseTool - elif ( - tool_service_url is not None - and tool_service_api_key is not None - and tool_service_name is not None - ): - return cls( - tool_service_name=tool_service_name, - tool_metadata=ToolMetadata("TODO"), - message_queue=message_queue, - publish_callback=publish_callback, - ) - else: - raise ValueError( - "Please supply either a ToolService or a triplet of {tool_service_url, tool_service_api_key, tool_service_name}." 
- ) - - @property - def message_queue(self) -> BaseMessageQueue: - return self._message_queue - - @property - def publisher_id(self) -> str: - return self._publisher_id - - @property - def publish_callback(self) -> Optional[PublishCallback]: - return self._publish_callback - - @property - def metadata(self) -> ToolMetadata: - return self._metadata - - @property - def lock(self) -> asyncio.Lock: - return self._lock - - async def process_message(self, message: QueueMessage, **kwargs: Any) -> None: - if message.action == ActionTypes.COMPLETED_TOOL_CALL: - tool_call_result = ToolCallResult(**message.data or {}) - async with self.lock: - self.tool_call_results.update({tool_call_result.id_: tool_call_result}) - else: - raise ValueError(f"Unhandled action: {message.action}") - - def as_consumer(self) -> BaseMessageQueueConsumer: - return CallableMessageConsumer( - id_=self.publisher_id, - message_type=self.publisher_id, - handler=self.process_message, - ) - - async def purge_old_tool_call_results(self, cutoff_date: str) -> None: - """Purge old tool call results. - - TODO: implement this. - """ - pass - - async def _poll_for_tool_call_result(self, tool_call_id: str) -> ToolCallResult: - tool_call_result = None - while tool_call_result is None: - async with self.lock: - tool_call_result = ( - self.tool_call_results[tool_call_id] - if tool_call_id in self.tool_call_results - else None - ) - - await asyncio.sleep(self.step_interval) - return tool_call_result - - async def deregister(self) -> None: - """Deregister from message queue.""" - await self.message_queue.deregister_consumer(self.as_consumer()) - self.registered = False - - def call(self, *args: Any, **kwargs: Any) -> ToolOutput: - """Call.""" - return asyncio.run(self.acall(*args, **kwargs)) - - async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput: - """Publish a call to the queue. - - In order to get a ToolOutput result, this will poll the queue until - the result is written. 
- """ - if not self.registered: - # register tool to message queue - await self.message_queue.register_consumer(self.as_consumer()) - self.registered = True - - tool_call = ToolCall( - tool_call_bundle=ToolCallBundle( - tool_name=self.metadata.name or "", - tool_args=list(args), - tool_kwargs=kwargs, - ), - source_id=self.publisher_id, - ) - await self.publish( - QueueMessage( - type=self.tool_service_name, - action=ActionTypes.NEW_TOOL_CALL, - data=tool_call.model_dump(), - ) - ) - - # poll for tool_call_result with max timeout - try: - tool_call_result = await asyncio.wait_for( - self._poll_for_tool_call_result(tool_call_id=tool_call.id_), - timeout=self.timeout, - ) - except ( - asyncio.exceptions.TimeoutError, - asyncio.TimeoutError, - TimeoutError, - ) as e: - logger.debug(f"Timeout reached for tool_call with id {tool_call.id_}") - if self.raise_timeout: - raise - return ToolOutput( - content="Encountered error: " + str(e), - tool_name=self.metadata.name or "", - raw_input={"args": args, "kwargs": kwargs}, - raw_output=str(e), - is_error=True, - ) - finally: - async with self.lock: - if tool_call.id_ in self.tool_call_results: - del self.tool_call_results[tool_call.id_] - - return ToolOutput( - content=tool_call_result.result, - tool_name=self.metadata.name or "", - raw_input={"args": args, "kwargs": kwargs}, - raw_output=tool_call_result.result, - ) - - def get_topic(self, msg_type: str) -> str: - return msg_type diff --git a/llama_deploy/tools/service_as_tool.py b/llama_deploy/tools/service_as_tool.py deleted file mode 100644 index 936604a9..00000000 --- a/llama_deploy/tools/service_as_tool.py +++ /dev/null @@ -1,287 +0,0 @@ -import asyncio -import logging -import uuid -from typing import Any, Dict, Optional - -from llama_index.core.tools import AsyncBaseTool, ToolMetadata, ToolOutput -from pydantic import BaseModel, Field, PrivateAttr - -from llama_deploy.message_consumers.base import BaseMessageQueueConsumer -from llama_deploy.message_consumers.callable import CallableMessageConsumer -from llama_deploy.message_publishers.publisher import ( - MessageQueuePublisherMixin, - PublishCallback, -) -from llama_deploy.message_queues.base import BaseMessageQueue -from llama_deploy.messages.base import QueueMessage -from llama_deploy.tools.utils import get_tool_name_from_service_name -from llama_deploy.types import ( - ActionTypes, - ServiceDefinition, - TaskDefinition, - ToolCallResult, -) - -logger = logging.getLogger(__name__) - - -class ServiceAsTool(MessageQueuePublisherMixin, AsyncBaseTool, BaseModel): - """Service As Tool. - - This class is a wrapper around any BaseService, providing a tool-like interface, - to be used as a tool in any other llama-index abstraction. 
- - NOTE: The BaseService must be able to process messages with action type: NEW_TOOL_CALL - """ - - tool_call_results: Dict[str, ToolCallResult] = Field(default_factory=dict) - timeout: float = Field(default=10.0, description="timeout interval in seconds.") - service_name: str = Field(default_factory=str) - step_interval: float = 0.1 - raise_timeout: bool = False - registered: bool = False - - _message_queue: BaseMessageQueue = PrivateAttr() - _publisher_id: str = PrivateAttr() - _publish_callback: Optional[PublishCallback] = PrivateAttr() - _lock: asyncio.Lock = PrivateAttr() - _metadata: ToolMetadata = PrivateAttr() - - def __init__( - self, - tool_metadata: ToolMetadata, - message_queue: BaseMessageQueue, - service_name: str, - publish_callback: Optional[PublishCallback] = None, - tool_call_results: Dict[str, ToolCallResult] = {}, - timeout: float = 60.0, - step_interval: float = 0.1, - raise_timeout: bool = False, - ) -> None: - """Class constructor. - - Args: - tool_metadata (ToolMetadata): Tool metadata. - message_queue (BaseMessageQueue): Message queue. - service_name (str): Service name. - publish_callback (Optional[PublishCallback], optional): Publish callback. Defaults to None. - tool_call_results (Dict[str, ToolCallResult], optional): Tool call results. Defaults to {}. - timeout (float, optional): Timeout. Defaults to 60.0s. - step_interval (float, optional): Step interval when polling for a result. Defaults to 0.1s. - raise_timeout (bool, optional): Raise timeout. Defaults to False. - - Examples: - ```python - from llama_deploy import AgentService, ServiceAsTool, SimpleMessageQueue - - message_queue = SimpleMessageQueue() - - agent1_server = AgentService( - agent=agent1, - message_queue=message_queue, - description="Useful for getting the secret fact.", - service_name="secret_fact_agent", - ) - - # create the tool for use in other agents - agent1_server_tool = ServiceAsTool.from_service_definition( - message_queue=message_queue, - service_definition=agent1_server.service_definition - ) - - # can also use the tool directly - result = await agent1_server_tool.acall(input="get the secret fact") - print(result) - ``` - """ - # validate fn_schema - if "input" not in tool_metadata.get_parameters_dict()["properties"]: - raise ValueError("Invalid FnSchema - 'input' field is required.") - - # validate tool name - if tool_metadata.name != get_tool_name_from_service_name(service_name): - raise ValueError("Tool name must be in the form '{{service_name}}-as-tool'") - - super().__init__( - tool_call_results=tool_call_results, - timeout=timeout, - step_interval=step_interval, - service_name=service_name, - raise_timeout=raise_timeout, - ) - self._message_queue = message_queue - self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}" - self._publish_callback = publish_callback - self._metadata = tool_metadata - self._lock = asyncio.Lock() - - @classmethod - def from_service_definition( - cls, - message_queue: BaseMessageQueue, - service_definition: ServiceDefinition, - publish_callback: Optional[PublishCallback] = None, - timeout: float = 60.0, - step_interval: float = 0.1, - raise_timeout: bool = False, - ) -> "ServiceAsTool": - """Create an ServiceAsTool from a ServiceDefinition. - - Args: - message_queue (BaseMessageQueue): Message queue. - service_definition (ServiceDefinition): Service definition. - publish_callback (Optional[PublishCallback], optional): Publish callback. Defaults to None. - timeout (float, optional): Timeout. Defaults to 60.0s. 
- step_interval (float, optional): Step interval. Defaults to 0.1s. - raise_timeout (bool, optional): Raise timeout. Defaults to False. - """ - tool_metadata = ToolMetadata( - description=service_definition.description, - name=get_tool_name_from_service_name(service_definition.service_name), - ) - return cls( - tool_metadata=tool_metadata, - message_queue=message_queue, - service_name=service_definition.service_name, - publish_callback=publish_callback, - timeout=timeout, - step_interval=step_interval, - raise_timeout=raise_timeout, - ) - - @property - def message_queue(self) -> BaseMessageQueue: - """The message queue used by the tool service.""" - return self._message_queue - - @property - def publisher_id(self) -> str: - """The publisher ID.""" - return self._publisher_id - - @property - def publish_callback(self) -> Optional[PublishCallback]: - """The publish callback, if any.""" - return self._publish_callback - - @property - def metadata(self) -> ToolMetadata: - """The tool metadata.""" - return self._metadata - - @property - def lock(self) -> asyncio.Lock: - return self._lock - - async def process_message(self, message: QueueMessage, **kwargs: Any) -> None: - """Process a message from the message queue.""" - if message.action == ActionTypes.COMPLETED_TOOL_CALL: - tool_call_result = ToolCallResult(**message.data or {}) - async with self.lock: - self.tool_call_results.update({tool_call_result.id_: tool_call_result}) - else: - raise ValueError(f"Unhandled action: {message.action}") - - def as_consumer(self) -> BaseMessageQueueConsumer: - """Return a message queue consumer for this tool.""" - return CallableMessageConsumer( - id_=self.publisher_id, - message_type=self.publisher_id, - handler=self.process_message, - ) - - async def purge_old_tool_call_results(self, cutoff_date: str) -> None: - """Purge old tool call results. - - TODO: implement this. - """ - pass - - async def _poll_for_tool_call_result(self, tool_call_id: str) -> ToolCallResult: - tool_call_result = None - while tool_call_result is None: - async with self.lock: - tool_call_result = ( - self.tool_call_results[tool_call_id] - if tool_call_id in self.tool_call_results - else None - ) - - await asyncio.sleep(self.step_interval) - return tool_call_result - - async def deregister(self) -> None: - """Deregister from message queue.""" - await self.message_queue.deregister_consumer(self.as_consumer()) - self.registered = False - - def call(self, *args: Any, **kwargs: Any) -> ToolOutput: - """Publish a call to the queue. - - In order to get a ToolOutput result, this will poll the queue until - the result is written. - """ - return asyncio.run(self.acall(*args, **kwargs)) - - def _parse_args(self, *args: Any, **kwargs: Any) -> str: - return kwargs.pop("input") - - async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput: - """Publish a call to the queue. - - In order to get a ToolOutput result, this will poll the queue until - the result is written. 
- """ - if not self.registered: - # register tool to message queue - start_consuming_callable = await self.message_queue.register_consumer( - self.as_consumer() - ) - _ = asyncio.create_task(start_consuming_callable()) - self.registered = True - - input = self._parse_args(*args, **kwargs) - task_def = TaskDefinition(input=input) - await self.publish( - QueueMessage( - type=self.service_name, - action=ActionTypes.NEW_TOOL_CALL, - data=task_def.model_dump(), - ) - ) - - # poll for tool_call_result with max timeout - try: - tool_call_result = await asyncio.wait_for( - self._poll_for_tool_call_result(tool_call_id=task_def.task_id), - timeout=self.timeout, - ) - except ( - asyncio.exceptions.TimeoutError, - asyncio.TimeoutError, - TimeoutError, - ) as e: - logger.debug(f"Timeout reached for tool_call with id {task_def.task_id}") - if self.raise_timeout: - raise - return ToolOutput( - content="Encountered error: " + str(e), - tool_name=self.metadata.name or "", - raw_input={"args": args, "kwargs": kwargs}, - raw_output=str(e), - is_error=True, - ) - finally: - async with self.lock: - if task_def.task_id in self.tool_call_results: - del self.tool_call_results[task_def.task_id] - - return ToolOutput( - content=tool_call_result.result, - tool_name=self.metadata.name or "", - raw_input={"args": args, "kwargs": kwargs}, - raw_output=tool_call_result.result, - ) - - def get_topic(self, msg_type: str) -> str: - return msg_type diff --git a/llama_deploy/tools/service_component.py b/llama_deploy/tools/service_component.py deleted file mode 100644 index 9c9a0b76..00000000 --- a/llama_deploy/tools/service_component.py +++ /dev/null @@ -1,142 +0,0 @@ -import json -from typing import Any, Dict, Optional - -from llama_index.core.bridge.pydantic import PrivateAttr -from llama_index.core.query_pipeline import CustomQueryComponent -from llama_index.core.base.query_pipeline.query import InputKeys - -from llama_deploy.types import ServiceDefinition -from enum import Enum - - -class ModuleType(str, Enum): - """Module types. - - Can be either an agent or a component. - - NOTE: this is to allow both agent services and component services to be stitched together with - the pipeline orchestrator. - - Ideally there should not be more types. - - """ - - AGENT = "agent" - COMPONENT = "component" - - -class ServiceComponent(CustomQueryComponent): - """Service component. - - This wraps a service into a component that can be used in a query pipeline. - - Attributes: - name (str): The name of the service. - description (str): The description of the service. - input_keys (Optional[InputKeys]): The input keys. Defaults to a single `input` key. - module_type (ModuleType): The module type. Defaults to `ModuleType.AGENT`. 
- - Examples: - ```python - from llama_deploy import ServiceComponent, AgentService - - rag_agent_server = AgentService( - agent=rag_agent, - message_queue=message_queue, - description="rag_agent", - ) - rag_agent_server_c = ServiceComponent.from_service_definition( - rag_agent_server.service_definition - ) - - pipeline = QueryPipeline(chain=[rag_agent_server_c]) - ``` - """ - - name: str - description: str - - # Store a set of input keys from upstream modules - # NOTE: no need to track the output keys, this is a fake module anyways - _cur_input_keys: InputKeys = PrivateAttr() - - module_type: ModuleType = ModuleType.AGENT - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._cur_input_keys = kwargs.get("input_keys") or InputKeys.from_keys( - {"input"} - ) - - @classmethod - def from_service_definition( - cls, - service_def: ServiceDefinition, - input_keys: Optional[InputKeys] = None, - module_type: ModuleType = ModuleType.AGENT, - ) -> "ServiceComponent": - """Create a service component from a service definition.""" - return cls( - name=service_def.service_name, - description=service_def.description, - input_keys=input_keys, - module_type=module_type, - ) - - @classmethod - def from_component_service( - cls, - component_service: Any, - ) -> "ServiceComponent": - """Create a service component from a component service.""" - from llama_deploy.services.component import ComponentService - - if not isinstance(component_service, ComponentService): - raise ValueError("component_service must be a Component") - - component = component_service.component - return cls.from_service_definition( - component_service.service_definition, - input_keys=component.input_keys, - module_type=ModuleType.COMPONENT, - ) - - @property - def input_keys(self) -> InputKeys: - """Input keys.""" - # NOTE: user can override this too, but we have them implement an - # abstract method to make sure they do it - - return self._cur_input_keys - - @property - def _input_keys(self) -> set: - """Input keys dict.""" - # HACK: not used - return set() - - @property - def _output_keys(self) -> set: - return {"service_output"} - - def _run_component(self, **kwargs: Any) -> Dict[str, Any]: - """Return a dummy output.""" - json_dump = json.dumps( - { - "name": self.name, - "description": self.description, - "input": kwargs, - } - ) - return {"service_output": json_dump} - - async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]: - """Return a dummy output.""" - json_dump = json.dumps( - { - "name": self.name, - "description": self.description, - "input": kwargs, - } - ) - return {"service_output": json_dump} diff --git a/llama_deploy/tools/service_tool.py b/llama_deploy/tools/service_tool.py deleted file mode 100644 index 6c66188b..00000000 --- a/llama_deploy/tools/service_tool.py +++ /dev/null @@ -1,45 +0,0 @@ -from llama_index.core.tools import AsyncBaseTool, ToolMetadata, ToolOutput - -from llama_deploy.types import ServiceDefinition - - -class ServiceTool(AsyncBaseTool): - """A tool that wraps a service. - - Mostly used under the hood by the agent orchestrator. - - Attributes: - name (str): - The name of the tool. - description (str): - The description of the tool. 
- """ - - def __init__(self, name: str, description: str) -> None: - self.name = name - self.description = description - - @classmethod - def from_service_definition(cls, service_def: ServiceDefinition) -> "ServiceTool": - return cls(service_def.service_name, service_def.description) - - @property - def metadata(self) -> ToolMetadata: - return ToolMetadata( - name=self.name, - description=self.description, - ) - - def _make_dummy_output(self, input: str) -> ToolOutput: - return ToolOutput( - content=input, - tool_name=self.name, - raw_input={"input": input}, - raw_output=input, - ) - - def call(self, input: str) -> ToolOutput: - return self._make_dummy_output(input) - - async def acall(self, input: str) -> ToolOutput: - return self._make_dummy_output(input) diff --git a/llama_deploy/tools/utils.py b/llama_deploy/tools/utils.py deleted file mode 100644 index fcae8d9a..00000000 --- a/llama_deploy/tools/utils.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Utility functions for tools.""" - - -def get_tool_name_from_service_name(service_name: str) -> str: - """Utility function for getting the reserved name of a tool derived by a service.""" - return f"{service_name}-as-tool" diff --git a/pyproject.toml b/pyproject.toml index 9844ad8d..56496e62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,8 +12,7 @@ omit = [ "tests/*", # deprecated modules "llama_deploy/client/async_client.py", - "llama_deploy/client/sync_client.py", - "llama_deploy/tools/*" + "llama_deploy/client/sync_client.py" ] [tool.poetry] diff --git a/tests/services/test_agent_service.py b/tests/services/test_agent_service.py deleted file mode 100644 index 246c000b..00000000 --- a/tests/services/test_agent_service.py +++ /dev/null @@ -1,23 +0,0 @@ -from llama_index.core.agent import ReActAgent -from llama_index.core.llms import MockLLM - -from llama_deploy.message_queues.simple import SimpleMessageQueueServer -from llama_deploy.services import AgentService - - -def test_init() -> None: - agent = ReActAgent.from_tools([], llm=MockLLM()) - server = AgentService( - agent, - SimpleMessageQueueServer(), # type:ignore - running=False, - description="Test Agent Server", - step_interval=0.5, - host="localhost", - port=8001, - ) - - assert server.agent == agent - assert server.running is False - assert server.description == "Test Agent Server" - assert server.step_interval == 0.5 diff --git a/tests/services/test_human_service.py b/tests/services/test_human_service.py deleted file mode 100644 index 3eed2d8e..00000000 --- a/tests/services/test_human_service.py +++ /dev/null @@ -1,311 +0,0 @@ -import asyncio -from typing import Any, List -from unittest.mock import MagicMock, patch - -import pytest -from pydantic import PrivateAttr, ValidationError - -from llama_deploy.message_consumers.base import BaseMessageQueueConsumer -from llama_deploy.message_queues.simple import SimpleMessageQueue -from llama_deploy.messages.base import QueueMessage -from llama_deploy.services import HumanService -from llama_deploy.services.human import HELP_REQUEST_TEMPLATE_STR -from llama_deploy.types import ( - CONTROL_PLANE_NAME, - ActionTypes, - ChatMessage, - TaskDefinition, -) - - -class MockMessageConsumer(BaseMessageQueueConsumer): - processed_messages: List[QueueMessage] = [] - _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) - - async def _process_message(self, message: QueueMessage, **kwargs: Any) -> None: - async with self._lock: - self.processed_messages.append(message) - - -@pytest.fixture() -def human_output_consumer() -> MockMessageConsumer: - 
return MockMessageConsumer(message_type=CONTROL_PLANE_NAME) - - -@pytest.mark.asyncio() -async def test_init() -> None: - # arrange - # act - human_service = HumanService( - message_queue=SimpleMessageQueue(), # type:ignore - running=False, - description="Test Human Service", - service_name="Test Human Service", - step_interval=0.5, - host="localhost", - port=8001, - ) - - # assert - assert not human_service.running - assert human_service.description == "Test Human Service" - assert human_service.service_name == "Test Human Service" - assert human_service.step_interval == 0.5 - - -def test_invalid_human_prompt_raises_validation_error() -> None: - # arrange - invalid_human_prompt_input_str = "{incorrect_param}" - human_service = HumanService( - message_queue=SimpleMessageQueue(), # type: ignore - host="localhost", - port=8001, - ) - - # act/assert - with pytest.raises(ValidationError): - # using invalid prompt at construction should fail - _ = HumanService( - human_input_prompt=invalid_human_prompt_input_str, - message_queue=SimpleMessageQueue(), # type: ignore - ) - with pytest.raises(ValueError): - # updating prompt should fail - human_service.human_input_prompt = invalid_human_prompt_input_str - - -@pytest.mark.asyncio() -@patch("llama_deploy.types.core.uuid") -async def test_create_task(mock_uuid: MagicMock) -> None: - # arrange - human_service = HumanService( - message_queue=SimpleMessageQueue(), # type: ignore - running=False, - description="Test Human Service", - service_name="Test Human Service", - step_interval=0.5, - host="localhost", - port=8001, - ) - mock_uuid.uuid4.return_value = "mock_id" - task = TaskDefinition(task_id="1", input="Mock human req.") - - # act - result = await human_service.create_task(task) - - # assert - assert result == {"task_id": task.task_id} - assert human_service._outstanding_human_tasks[0].task_def == task - - -@pytest.mark.asyncio() -@patch("builtins.input") -async def test_process_task( - mock_input: MagicMock, - human_output_consumer: MockMessageConsumer, - message_queue_server: Any, -) -> None: - # arrange - mq = SimpleMessageQueue() - human_service = HumanService( - message_queue=mq, # type: ignore - host="localhost", - port=8001, - ) - - consumer_fn = await mq.register_consumer( - human_output_consumer, topic="llama_deploy.control_plane" - ) - consumer_task = asyncio.create_task(consumer_fn()) - service_task = asyncio.create_task(human_service.processing_loop()) - await asyncio.sleep(0.5) - mock_input.return_value = "Test human input." - - # act - req = TaskDefinition(task_id="1", input="Mock human req.") - result = await human_service.create_task(req) - await asyncio.sleep(0.5) - - # tear down - consumer_task.cancel() - service_task.cancel() - await asyncio.gather(consumer_task, service_task) - - # assert - mock_input.assert_called_once() - mock_input.assert_called_with( - HELP_REQUEST_TEMPLATE_STR.format(input_str="Mock human req.") - ) - assert len(human_output_consumer.processed_messages) == 1 - assert ( - human_output_consumer.processed_messages[0].data.get("result") - == "Test human input." 
- ) - assert human_output_consumer.processed_messages[0].data.get("task_id") == "1" - assert result == {"task_id": req.task_id} - assert len(human_service._outstanding_human_tasks) == 0 - - -@pytest.mark.asyncio() -@patch("builtins.input") -async def test_process_human_req_from_queue( - mock_input: MagicMock, - human_output_consumer: MockMessageConsumer, - message_queue_server: Any, -) -> None: - # arrange - mq = SimpleMessageQueue() - human_service = HumanService( - message_queue=mq, # type: ignore - service_name="test_human_service", - host="localhost", - port=8001, - ) - - consumer_fn = await mq.register_consumer( - human_output_consumer, topic="llama_deploy.control_plane" - ) - consumer_task = asyncio.create_task(consumer_fn()) - - service_task = asyncio.create_task(human_service.processing_loop()) - service_consumer_fn = await mq.register_consumer( - human_service.as_consumer(), topic="test_human_service" - ) - service_consumer_task = asyncio.create_task(service_consumer_fn()) - await asyncio.sleep(0.5) - mock_input.return_value = "Test human input." - - # act - req = TaskDefinition(task_id="1", input="Mock human req.") - human_req_message = QueueMessage( - data=req.model_dump(), - action=ActionTypes.NEW_TASK, - type="test_human_service", - ) - await mq.publish(human_req_message, topic="test_human_service") - await asyncio.sleep(0.5) - - # tear down - consumer_task.cancel() - service_task.cancel() - service_consumer_task.cancel() - await asyncio.gather(consumer_task, service_task, service_consumer_task) - - # assert - assert human_service.message_queue == mq - assert len(human_output_consumer.processed_messages) == 1 - assert ( - human_output_consumer.processed_messages[0].data.get("result") - == "Test human input." - ) - assert human_output_consumer.processed_messages[0].data.get("task_id") == "1" - assert len(human_service._outstanding_human_tasks) == 0 - - -@pytest.mark.asyncio() -async def test_process_task_with_custom_human_input_fn( - human_output_consumer: MockMessageConsumer, message_queue_server: Any -) -> None: - # arrange - mq = SimpleMessageQueue() - - async def my_custom_human_input_fn(prompt: str, task_id: str, **kwargs: Any) -> str: - return " ".join([prompt, prompt[::-1]]) - - human_service = HumanService( - message_queue=mq, # type:ignore - fn_input=my_custom_human_input_fn, - human_input_prompt="{input_str}", - host="localhost", - port=8001, - ) - - consumer_fn = await mq.register_consumer( - human_output_consumer, topic="llama_deploy.control_plane" - ) - consumer_task = asyncio.create_task(consumer_fn()) - service_task = asyncio.create_task(human_service.processing_loop()) - await asyncio.sleep(0.5) - - # act - req = TaskDefinition(task_id="1", input="Mock human req.") - result = await human_service.create_task(req) - await asyncio.sleep(0.5) - - # tear down - consumer_task.cancel() - service_task.cancel() - await asyncio.gather(consumer_task, service_task) - - # assert - assert len(human_output_consumer.processed_messages) == 1 - assert ( - human_output_consumer.processed_messages[0].data.get("result") - == "Mock human req. 
.qer namuh kcoM" - ) - assert human_output_consumer.processed_messages[0].data.get("task_id") == "1" - assert result == {"task_id": req.task_id} - assert len(human_service._outstanding_human_tasks) == 0 - - -@pytest.mark.asyncio() -@patch("builtins.input") -async def test_process_task_as_tool_call( - mock_input: MagicMock, - message_queue_server: Any, -) -> None: - # arrange - mq = SimpleMessageQueue() - human_service = HumanService( - message_queue=mq, # type: ignore - service_name="test_human_service", - host="localhost", - port=8001, - ) - output_consumer = MockMessageConsumer(message_type="tool_call_source") - consumer_fn = await mq.register_consumer( - output_consumer, topic="llama_deploy.tool_call_source" - ) - consumer_task = asyncio.create_task(consumer_fn()) - - service_task = asyncio.create_task(human_service.processing_loop()) - service_consumer_fn = await mq.register_consumer( - human_service.as_consumer(), topic="test_human_service" - ) - service_consumer_task = asyncio.create_task(service_consumer_fn()) - await asyncio.sleep(0.5) - - mock_input.return_value = "Test human input." - - # act - req = TaskDefinition(task_id="1", input="Mock human req.") - human_req_message = QueueMessage( - publisher_id="tool_call_source", - data=req.model_dump(), - action=ActionTypes.NEW_TOOL_CALL, - type="test_human_service", - ) - await mq.publish(human_req_message, topic="test_human_service") - await asyncio.sleep(0.5) - - # tear down - consumer_task.cancel() - service_task.cancel() - service_consumer_task.cancel() - await asyncio.gather(consumer_task, service_task, service_consumer_task) - - # assert - assert human_service.tool_name == "test_human_service-as-tool" - assert len(output_consumer.processed_messages) == 1 - assert ( - output_consumer.processed_messages[0].data.get("result") == "Test human input." - ) - try: - tool_message = ChatMessage.model_validate( - output_consumer.processed_messages[0].data.get("tool_message") - ) - assert tool_message.role == "tool" - except ValidationError: - pytest.fail("Unable to parse result into a ChatMessage object.") - assert output_consumer.processed_messages[0].data.get("id_") == "1" - assert len(human_service._outstanding_human_tasks) == 0 diff --git a/tests/tools/test_agent_service_as_tool.py b/tests/tools/test_agent_service_as_tool.py deleted file mode 100644 index 88db3564..00000000 --- a/tests/tools/test_agent_service_as_tool.py +++ /dev/null @@ -1,218 +0,0 @@ -import asyncio -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from llama_index.core.agent import AgentChatResponse, ReActAgent -from llama_index.core.agent.types import Task, TaskStep, TaskStepOutput -from llama_index.core.llms import MockLLM -from llama_index.core.memory import ChatMemoryBuffer -from llama_index.core.tools import FunctionTool - -from llama_deploy.message_queues.simple import SimpleMessageQueueServer -from llama_deploy.services.agent import AgentService -from llama_deploy.tools import ServiceAsTool - -pytestmark = pytest.mark.skip - - -@pytest.fixture() -def message_queue() -> SimpleMessageQueueServer: - return SimpleMessageQueueServer() - - -@pytest.fixture() -def agent_service(message_queue: SimpleMessageQueueServer) -> AgentService: - # create an agent - def get_the_secret_fact() -> str: - """Returns the secret fact.""" - return "The secret fact is: A baby llama is called a 'Cria'." 
- - tool = FunctionTool.from_defaults(fn=get_the_secret_fact) - - agent = ReActAgent.from_tools([tool], llm=MockLLM()) - return AgentService( - agent, - message_queue=message_queue, - description="Test Agent Server", - host="https://mock-agent-service.io", - port=8000, - ) - - -@pytest.fixture() -def task_step_output() -> TaskStepOutput: - return TaskStepOutput( - output=AgentChatResponse(response="A baby llama is called a 'Cria'."), - task_step=TaskStep(task_id="", step_id=""), - next_steps=[], - is_last=True, - ) - - -@pytest.fixture() -def completed_task() -> Task: - return Task( - task_id="", - input="What is the secret fact?", - memory=ChatMemoryBuffer.from_defaults(), - ) - - -@pytest.mark.asyncio() -@patch.object(ReActAgent, "arun_step") -@patch.object(ReActAgent, "get_completed_tasks") -async def test_tool_call_output( - mock_get_completed_tasks: MagicMock, - mock_arun_step: AsyncMock, - message_queue: SimpleMessageQueueServer, - agent_service: AgentService, - task_step_output: TaskStepOutput, - completed_task: Task, -) -> None: - # arrange - def arun_side_effect(task_id: str) -> TaskStepOutput: - completed_task.task_id = task_id - task_step_output.task_step.task_id = task_id - return task_step_output - - mock_arun_step.side_effect = arun_side_effect - mock_get_completed_tasks.side_effect = [ - [], - [], - [completed_task], - [completed_task], - [completed_task], - ] - - agent_service_tool = ServiceAsTool.from_service_definition( - message_queue=message_queue, - service_definition=agent_service.service_definition, - ) - - # startup - await message_queue.register_consumer(agent_service.as_consumer()) - mq_task = asyncio.create_task(message_queue.processing_loop()) - as_task = asyncio.create_task(agent_service.processing_loop()) - - # act - tool_output = await agent_service_tool.acall(input="What is the secret fact?") - - # clean-up/shutdown - await asyncio.sleep(0.1) - mq_task.cancel() - as_task.cancel() - - # assert - assert tool_output.content == "A baby llama is called a 'Cria'." 
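For orientation while reading these deleted tests: the `ServiceAsTool` wiring they exercise looked roughly like the sketch below, reconstructed only from the fixtures and assertions being removed here. The `MockLLM` stands in for a real model (the tests patch `ReActAgent.arun_step` instead of actually running it, so a bare `MockLLM` would not produce the secret fact on its own), and everything else mirrors the pre-patch API.

```python
import asyncio

from llama_index.core.agent import ReActAgent
from llama_index.core.llms import MockLLM
from llama_index.core.tools import FunctionTool

# Pre-patch imports: both names are removed from llama_deploy by this change.
from llama_deploy import AgentService, ServiceAsTool
from llama_deploy.message_queues.simple import SimpleMessageQueueServer


async def main() -> None:
    message_queue = SimpleMessageQueueServer()

    def get_the_secret_fact() -> str:
        """Returns the secret fact."""
        return "The secret fact is: A baby llama is called a 'Cria'."

    agent = ReActAgent.from_tools(
        [FunctionTool.from_defaults(fn=get_the_secret_fact)], llm=MockLLM()
    )
    agent_service = AgentService(
        agent,
        message_queue=message_queue,
        description="Test Agent Server",
        host="localhost",
        port=8000,
    )

    # Wrap the service so another agent can call it like a local tool.
    # The reserved tool name is "<service_name>-as-tool".
    agent_tool = ServiceAsTool.from_service_definition(
        message_queue=message_queue,
        service_definition=agent_service.service_definition,
        timeout=5.5,          # how long acall() waits for the service's result
        raise_timeout=False,  # on timeout, return an error ToolOutput instead of raising
    )

    # The queue and service loops must be running for the call to resolve.
    await message_queue.register_consumer(agent_service.as_consumer())
    mq_task = asyncio.create_task(message_queue.processing_loop())
    service_task = asyncio.create_task(agent_service.processing_loop())

    tool_output = await agent_tool.acall(input="What is the secret fact?")
    print(tool_output.content)

    mq_task.cancel()
    service_task.cancel()


if __name__ == "__main__":
    asyncio.run(main())
```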
- assert tool_output.tool_name == agent_service_tool.metadata.name - assert tool_output.raw_input == { - "args": (), - "kwargs": {"input": "What is the secret fact?"}, - } - assert len(agent_service_tool.tool_call_results) == 0 - assert agent_service_tool.registered is True - - -@pytest.mark.asyncio() -@patch.object(ReActAgent, "arun_step") -@patch.object(ReActAgent, "get_completed_tasks") -async def test_tool_call_raises_timeout_error( - mock_get_completed_tasks: MagicMock, - mock_arun_step: AsyncMock, - message_queue: SimpleMessageQueueServer, - agent_service: AgentService, - task_step_output: TaskStepOutput, - completed_task: Task, -) -> None: - # arrange - def arun_side_effect(task_id: str) -> TaskStepOutput: - completed_task.task_id = task_id - task_step_output.task_step.task_id = task_id - return task_step_output - - mock_arun_step.side_effect = arun_side_effect - mock_get_completed_tasks.side_effect = [ - [], - [], - [completed_task], - [completed_task], - [completed_task], - ] - - agent_service_tool = ServiceAsTool.from_service_definition( - message_queue=message_queue, - service_definition=agent_service.service_definition, - timeout=1e-12, - raise_timeout=True, - ) - - # startup - await message_queue.register_consumer(agent_service.as_consumer()) - mq_task = asyncio.create_task(message_queue.processing_loop()) - as_task = asyncio.create_task(agent_service.processing_loop()) - - # act/assert - with pytest.raises( - (TimeoutError, asyncio.TimeoutError, asyncio.exceptions.TimeoutError) - ): - await agent_service_tool.acall(input="What is the secret fact?") - - # clean-up/shutdown - mq_task.cancel() - as_task.cancel() - - -@pytest.mark.asyncio() -@patch.object(ReActAgent, "arun_step") -@patch.object(ReActAgent, "get_completed_tasks") -async def test_tool_call_hits_timeout_but_returns_tool_output( - mock_get_completed_tasks: MagicMock, - mock_arun_step: AsyncMock, - message_queue: SimpleMessageQueueServer, - agent_service: AgentService, - task_step_output: TaskStepOutput, - completed_task: Task, -) -> None: - # arrange - def arun_side_effect(task_id: str) -> TaskStepOutput: - completed_task.task_id = task_id - task_step_output.task_step.task_id = task_id - return task_step_output - - mock_arun_step.side_effect = arun_side_effect - mock_get_completed_tasks.side_effect = [ - [], - [], - [completed_task], - [completed_task], - [completed_task], - ] - - agent_service_tool = ServiceAsTool.from_service_definition( - message_queue=message_queue, - service_definition=agent_service.service_definition, - timeout=1e-12, - raise_timeout=False, - ) - - # startup - await message_queue.register_consumer(agent_service.as_consumer()) - mq_task = asyncio.create_task(message_queue.processing_loop()) - as_task = asyncio.create_task(agent_service.processing_loop()) - - # act/assert - tool_output = await agent_service_tool.acall(input="What is the secret fact?") - - # clean-up/shutdown - mq_task.cancel() - as_task.cancel() - - assert "Encountered error" in tool_output.content - assert tool_output.is_error - assert tool_output.tool_name == agent_service_tool.metadata.name - assert tool_output.raw_input == { - "args": (), - "kwargs": {"input": "What is the secret fact?"}, - } - assert len(agent_service_tool.tool_call_results) == 0 - assert agent_service_tool.registered is True diff --git a/tests/tools/test_human_service_as_tool.py b/tests/tools/test_human_service_as_tool.py deleted file mode 100644 index aacda81c..00000000 --- a/tests/tools/test_human_service_as_tool.py +++ /dev/null @@ -1,145 +0,0 @@ -import 
asyncio -import time -from unittest.mock import MagicMock, patch - -import pytest - -from llama_deploy.message_queues.simple import SimpleMessageQueueServer -from llama_deploy.services.human import HumanService -from llama_deploy.tools.service_as_tool import ServiceAsTool - - -@pytest.fixture() -def message_queue() -> SimpleMessageQueueServer: - return SimpleMessageQueueServer() - - -pytestmark = pytest.mark.skip - - -@pytest.fixture() -def human_service(message_queue: SimpleMessageQueueServer) -> HumanService: - return HumanService( - message_queue=message_queue, - description="Test Human Service", - service_name="test_human_service", - host="https://mock-human-service.io", - port=8000, - ) - - -@pytest.mark.asyncio() -@patch("builtins.input") -async def test_tool_call_output( - mock_input: MagicMock, - message_queue: SimpleMessageQueueServer, - human_service: HumanService, -) -> None: - # arrange - human_service_as_tool = ServiceAsTool.from_service_definition( - message_queue=message_queue, - service_definition=human_service.service_definition, - ) - mock_input.return_value = "Test human input." - - # startup - await message_queue.register_consumer(human_service.as_consumer()) - mq_task = asyncio.create_task(message_queue.processing_loop()) - hs_task = asyncio.create_task(human_service.processing_loop()) - - # act - tool_output = await human_service_as_tool.acall(input="Mock human request") - - # clean-up/shutdown - await asyncio.sleep(0.1) - mq_task.cancel() - hs_task.cancel() - - # assert - assert tool_output.content == "Test human input." - assert tool_output.tool_name == human_service_as_tool.metadata.name - assert tool_output.raw_input == { - "args": (), - "kwargs": {"input": "Mock human request"}, - } - assert len(human_service_as_tool.tool_call_results) == 0 - assert human_service_as_tool.registered is True - - -@pytest.mark.asyncio() -@patch("builtins.input") -async def test_tool_call_raises_timeout_error( - mock_input: MagicMock, - message_queue: SimpleMessageQueueServer, - human_service: HumanService, -) -> None: - # arrange - def input_side_effect(prompt: str) -> str: - time.sleep(0.1) - return prompt - - mock_input.side_effect = input_side_effect - human_service_as_tool = ServiceAsTool.from_service_definition( - message_queue=message_queue, - service_definition=human_service.service_definition, - timeout=1e-12, - raise_timeout=True, - ) - - # startup - await message_queue.register_consumer(human_service.as_consumer()) - mq_task = asyncio.create_task(message_queue.processing_loop()) - hs_task = asyncio.create_task(human_service.processing_loop()) - - # act/assert - with pytest.raises( - (TimeoutError, asyncio.TimeoutError, asyncio.exceptions.TimeoutError) - ): - await human_service_as_tool.acall(input="Is this a mock request?") - - # clean-up/shutdown - mq_task.cancel() - hs_task.cancel() - - -@pytest.mark.asyncio() -@patch("builtins.input") -async def test_tool_call_hits_timeout_but_returns_tool_output( - mock_input: MagicMock, - message_queue: SimpleMessageQueueServer, - human_service: HumanService, -) -> None: - # arrange - def input_side_effect(prompt: str) -> str: - time.sleep(0.1) - return prompt - - mock_input.side_effect = input_side_effect - human_service_as_tool = ServiceAsTool.from_service_definition( - message_queue=message_queue, - service_definition=human_service.service_definition, - timeout=1e-12, - raise_timeout=False, - ) - - # startup - await message_queue.register_consumer(human_service.as_consumer()) - mq_task = 
asyncio.create_task(message_queue.processing_loop()) - hs_task = asyncio.create_task(human_service.processing_loop()) - - # act/assert - tool_output = await human_service_as_tool.acall(input="Is this a mock request?") - - # clean-up/shutdown - mq_task.cancel() - hs_task.cancel() - - assert "Encountered error" in tool_output.content - assert tool_output.is_error - assert tool_output.tool_name == human_service_as_tool.metadata.name - assert tool_output.raw_input == { - "args": (), - "kwargs": {"input": "Is this a mock request?"}, - } - assert len(human_service_as_tool.tool_call_results) == 0 - assert human_service_as_tool.registered is True diff --git a/tests/tools/test_meta_service_tool.py b/tests/tools/test_meta_service_tool.py deleted file mode 100644 index 19e44db6..00000000 --- a/tests/tools/test_meta_service_tool.py +++ /dev/null @@ -1,223 +0,0 @@ -import asyncio -from typing import Any, Dict, List - -import pytest -from llama_index.core.tools import BaseTool, FunctionTool -from pydantic import PrivateAttr - -from llama_deploy.message_consumers.base import BaseMessageQueueConsumer -from llama_deploy.message_queues.simple import SimpleMessageQueueServer -from llama_deploy.messages.base import QueueMessage -from llama_deploy.services import ToolService -from llama_deploy.tools import MetaServiceTool - -pytestmark = pytest.mark.skip - - -class MockMessageConsumer(BaseMessageQueueConsumer): - processed_messages: List[QueueMessage] = [] - _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) - - async def _process_message(self, message: QueueMessage, **kwargs: Any) -> None: - async with self._lock: - self.processed_messages.append(message) - - -@pytest.fixture() -def tools() -> List[BaseTool]: - def multiply(a: int, b: int) -> int: - """Multiple two integers and returns the result integer""" - return a * b - - return [FunctionTool.from_defaults(fn=multiply)] - - -@pytest.fixture() -def message_queue() -> SimpleMessageQueueServer: - return SimpleMessageQueueServer() - - -@pytest.fixture() -def tool_service( - message_queue: SimpleMessageQueueServer, tools: List[BaseTool] -) -> ToolService: - return ToolService( - message_queue=message_queue, - tools=tools, - running=True, - service_name="test_tool_service", - description="Test Tool Server", - step_interval=0.5, - host="localhost", - port=8001, - ) - - -@pytest.mark.asyncio() -async def test_init( - message_queue: SimpleMessageQueueServer, - tools: List[BaseTool], - tool_service: ToolService, -) -> None: - # arrange - result = await tool_service.get_tool_by_name("multiply") - - # act - meta_service_tool = MetaServiceTool( - tool_metadata=result["tool_metadata"], - message_queue=message_queue, - tool_service_name=tool_service.service_name, - ) - - # assert - assert meta_service_tool.metadata.name == "multiply" - assert not meta_service_tool.registered - - -@pytest.mark.asyncio() -async def test_create_from_tool_service_direct( - message_queue: SimpleMessageQueueServer, tool_service: ToolService -) -> None: - # arrange - - # act - meta_service_tool: MetaServiceTool = await MetaServiceTool.from_tool_service( - tool_service=tool_service, message_queue=message_queue, name="multiply" - ) - - # assert - assert meta_service_tool.metadata.name == "multiply" - assert not meta_service_tool.registered - - -@pytest.mark.asyncio() -@pytest.mark.parametrize( - ("from_tool_service_kwargs"), - [ - {"message_queue": SimpleMessageQueueServer(), "name": "multiply"}, - { - "message_queue": SimpleMessageQueueServer(), - "name": "multiply", - 
"tool_service_name": "fake-name", - }, - { - "message_queue": SimpleMessageQueueServer(), - "name": "multiply", - "tool_service_api_key": "fake-key", - }, - { - "message_queue": SimpleMessageQueueServer(), - "name": "multiply", - "tool_service_url": "fake-url", - }, - { - "message_queue": SimpleMessageQueueServer(), - "name": "multiply", - "tool_service_name": "fake-name", - "tool_service_api_key": "fake-key", - }, - { - "message_queue": SimpleMessageQueueServer(), - "name": "multiply", - "tool_service_name": "fake-name", - "tool_service_url": "fake-url", - }, - { - "message_queue": SimpleMessageQueueServer(), - "name": "multiply", - "tool_service_api_key": "fake-key", - "tool_service_url": "fake-url", - }, - ], -) -async def test_create_from_tool_service_raise_error( - from_tool_service_kwargs: Dict[str, Any], -) -> None: - # arrange - # act/assert - with pytest.raises(ValueError): - await MetaServiceTool.from_tool_service(**from_tool_service_kwargs) - - -@pytest.mark.asyncio() -async def test_tool_call_output( - message_queue: SimpleMessageQueueServer, tool_service: ToolService -) -> None: - # arrange - meta_service_tool: MetaServiceTool = await MetaServiceTool.from_tool_service( - tool_service=tool_service, message_queue=message_queue, name="multiply" - ) - await message_queue.register_consumer(tool_service.as_consumer()) - mq_task = await message_queue.launch_local() - ts_task = asyncio.create_task(tool_service.processing_loop()) - - # act - tool_output = await meta_service_tool.acall(a=1, b=9) - - # clean-up/shutdown - await asyncio.sleep(0.5) - mq_task.cancel() - ts_task.cancel() - - # assert - assert tool_output.content == "9" - assert tool_output.tool_name == "multiply" - assert tool_output.raw_input == {"args": (), "kwargs": {"a": 1, "b": 9}} - assert len(meta_service_tool.tool_call_results) == 0 - assert meta_service_tool.registered - - -@pytest.mark.asyncio() -async def test_tool_call_raise_timeout( - message_queue: SimpleMessageQueueServer, tool_service: ToolService -) -> None: - # arrange - meta_service_tool: MetaServiceTool = await MetaServiceTool.from_tool_service( - tool_service=tool_service, - message_queue=message_queue, - name="multiply", - timeout=1e-9, - raise_timeout=True, - ) - await message_queue.register_consumer(tool_service.as_consumer()) - mq_task = await message_queue.launch_local() - ts_task = asyncio.create_task(tool_service.processing_loop()) - - # act/assert - with pytest.raises( - (TimeoutError, asyncio.TimeoutError, asyncio.exceptions.TimeoutError) - ): - await meta_service_tool.acall(a=1, b=9) - - mq_task.cancel() - ts_task.cancel() - - -@pytest.mark.asyncio() -async def test_tool_call_reach_timeout( - message_queue: SimpleMessageQueueServer, tool_service: ToolService -) -> None: - # arrange - meta_service_tool: MetaServiceTool = await MetaServiceTool.from_tool_service( - tool_service=tool_service, - message_queue=message_queue, - name="multiply", - timeout=1e-9, - raise_timeout=False, - ) - await message_queue.register_consumer(tool_service.as_consumer()) - mq_task = await message_queue.launch_local() - ts_task = asyncio.create_task(tool_service.processing_loop()) - - # act/assert - tool_output = await meta_service_tool.acall(a=1, b=9) - - mq_task.cancel() - ts_task.cancel() - - assert "Encountered error" in tool_output.content - assert tool_output.tool_name == "multiply" - assert tool_output.is_error - assert tool_output.raw_input == {"args": (), "kwargs": {"a": 1, "b": 9}} - assert len(meta_service_tool.tool_call_results) == 0 - assert 
meta_service_tool.registered diff --git a/tests/tools/test_service_as_tool.py b/tests/tools/test_service_as_tool.py deleted file mode 100644 index 865f8a2f..00000000 --- a/tests/tools/test_service_as_tool.py +++ /dev/null @@ -1,135 +0,0 @@ -import pytest -from llama_index.core.agent import ReActAgent -from llama_index.core.llms import MockLLM -from llama_index.core.tools import FunctionTool, ToolMetadata - -from llama_deploy.message_queues.simple import SimpleMessageQueueServer -from llama_deploy.services.agent import AgentService -from llama_deploy.services.human import HumanService -from llama_deploy.tools.service_as_tool import ServiceAsTool - - -@pytest.fixture() -def message_queue() -> SimpleMessageQueueServer: - return SimpleMessageQueueServer() - - -@pytest.fixture() -def human_service(message_queue: SimpleMessageQueueServer) -> HumanService: - return HumanService( - message_queue=message_queue, - running=False, - description="Test Human Service", - service_name="test_human_service", - host="https://mock-human-service.io", - port=8000, - ) - - -@pytest.fixture() -def agent_service(message_queue: SimpleMessageQueueServer) -> AgentService: - # create an agent - def get_the_secret_fact() -> str: - """Returns the secret fact.""" - return "The secret fact is: A baby llama is called a 'Cria'." - - tool = FunctionTool.from_defaults(fn=get_the_secret_fact) - - agent = ReActAgent.from_tools([tool], llm=MockLLM()) - return AgentService( - agent, - message_queue=message_queue, - description="Test Agent Server", - host="https://mock-agent-service.io", - port=8000, - ) - - -@pytest.mark.parametrize( - ("service_type"), - ["human_service", "agent_service"], -) -def test_init( - message_queue: SimpleMessageQueueServer, - service_type: str, - request: pytest.FixtureRequest, -) -> None: - # arrange - service = request.getfixturevalue(service_type) - tool_metadata = ToolMetadata( - description=service.description, - name=service.tool_name, - ) - # act - agent_service_tool = ServiceAsTool( - tool_metadata=tool_metadata, - message_queue=message_queue, - service_name=service.service_name, - timeout=5.5, - step_interval=0.5, - ) - - # assert - assert agent_service_tool.step_interval == 0.5 - assert agent_service_tool.message_queue == message_queue - assert agent_service_tool.metadata == tool_metadata - assert agent_service_tool.timeout == 5.5 - assert agent_service_tool.service_name == service.service_name - assert agent_service_tool.registered is False - - -@pytest.mark.parametrize( - ("service_type"), - ["human_service", "agent_service"], -) -def test_init_invalid_tool_name_should_raise_error( - message_queue: SimpleMessageQueueServer, - service_type: str, - request: pytest.FixtureRequest, -) -> None: - # arrange - service = request.getfixturevalue(service_type) - tool_metadata = ToolMetadata( - description=service.description, - name="incorrect-name", - ) - # act/assert - with pytest.raises(ValueError): - ServiceAsTool( - tool_metadata=tool_metadata, - message_queue=message_queue, - service_name=service.service_name, - ) - - -@pytest.mark.parametrize( - ("service_type"), - ["human_service", "agent_service"], -) -def test_from_service_definition( - message_queue: SimpleMessageQueueServer, - service_type: str, - request: pytest.FixtureRequest, -) -> None: - # arrange - service = request.getfixturevalue(service_type) - service_def = service.service_definition - - # act - agent_service_tool = ServiceAsTool.from_service_definition( - message_queue=message_queue, - service_definition=service_def, - 
timeout=5.5, - step_interval=0.5, - raise_timeout=True, - ) - - # assert - assert agent_service_tool.step_interval == 0.5 - assert agent_service_tool.message_queue == message_queue - assert agent_service_tool.metadata.description == service_def.description - assert agent_service_tool.metadata.name == f"{service_def.service_name}-as-tool" - assert agent_service_tool.timeout == 5.5 - assert agent_service_tool.service_name == service.service_name - assert agent_service_tool.raise_timeout is True - assert agent_service_tool.registered is False From 59efa8d64dda892da388d5b03e24e24eda83b6c4 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Mon, 23 Dec 2024 09:43:29 +0100 Subject: [PATCH 2/4] removed unused services --- llama_deploy/__init__.py | 4 - llama_deploy/services/__init__.py | 16 -- llama_deploy/services/component.py | 264 ---------------------- llama_deploy/services/tool.py | 335 ---------------------------- llama_deploy/services/types.py | 120 ---------- tests/services/test_tool_service.py | 181 --------------- 6 files changed, 920 deletions(-) delete mode 100644 llama_deploy/services/component.py delete mode 100644 llama_deploy/services/tool.py delete mode 100644 llama_deploy/services/types.py delete mode 100644 tests/services/test_tool_service.py diff --git a/llama_deploy/__init__.py b/llama_deploy/__init__.py index 788ff1de..a7ca6af8 100644 --- a/llama_deploy/__init__.py +++ b/llama_deploy/__init__.py @@ -12,8 +12,6 @@ from llama_deploy.messages import QueueMessage from llama_deploy.orchestrators import SimpleOrchestrator, SimpleOrchestratorConfig from llama_deploy.services import ( - ComponentService, - ToolService, WorkflowService, WorkflowServiceConfig, ) @@ -35,8 +33,6 @@ "AsyncLlamaDeployClient", "Client", # services - "ToolService", - "ComponentService", "WorkflowService", "WorkflowServiceConfig", # messages diff --git a/llama_deploy/services/__init__.py b/llama_deploy/services/__init__.py index 45f64223..41530c3e 100644 --- a/llama_deploy/services/__init__.py +++ b/llama_deploy/services/__init__.py @@ -1,24 +1,8 @@ from llama_deploy.services.base import BaseService -from llama_deploy.services.component import ComponentService -from llama_deploy.services.tool import ToolService -from llama_deploy.services.types import ( - _ChatMessage, - _Task, - _TaskSate, - _TaskStep, - _TaskStepOutput, -) from llama_deploy.services.workflow import WorkflowService, WorkflowServiceConfig __all__ = [ "BaseService", - "ToolService", - "ComponentService", "WorkflowService", "WorkflowServiceConfig", - "_Task", - "_TaskSate", - "_TaskStep", - "_TaskStepOutput", - "_ChatMessage", ] diff --git a/llama_deploy/services/component.py b/llama_deploy/services/component.py deleted file mode 100644 index e5f72d9e..00000000 --- a/llama_deploy/services/component.py +++ /dev/null @@ -1,264 +0,0 @@ -import asyncio -import json -import uuid -from contextlib import asynccontextmanager -from logging import getLogger -from typing import Any, AsyncGenerator, Dict, Optional - -import uvicorn -from fastapi import FastAPI -from llama_index.core.bridge.pydantic import PrivateAttr -from llama_index.core.query_pipeline import QueryComponent - -from llama_deploy.control_plane.server import CONTROL_PLANE_MESSAGE_TYPE -from llama_deploy.message_consumers.base import BaseMessageQueueConsumer -from llama_deploy.message_consumers.callable import CallableMessageConsumer -from llama_deploy.message_consumers.remote import RemoteMessageConsumer -from llama_deploy.message_publishers.publisher import PublishCallback -from 
llama_deploy.message_queues.base import BaseMessageQueue -from llama_deploy.messages.base import QueueMessage -from llama_deploy.services.base import BaseService -from llama_deploy.types import ( - ActionTypes, - ServiceDefinition, - TaskDefinition, - TaskResult, -) - -logger = getLogger(__name__) - - -class ComponentService(BaseService): - """Component service. - - Wraps a query pipeline component into a service. - - Exposes the following endpoints: - - GET `/`: Home endpoint. - - POST `/process_message`: Process a message. - - Attributes: - component (Any): The query pipeline component. - description (str): The description of the service. - running (bool): Whether the service is running. - step_interval (float): The interval in seconds to poll for tool call results. Defaults to 0.1s. - host (Optional[str]): The host of the service. - port (Optional[int]): The port of the service. - raise_exceptions (bool): Whether to raise exceptions. - - Examples: - ```python - from llama_deploy import ComponentService - from llama_index.core.query_pipeline import QueryComponent - - component_service = ComponentService( - component=query_component, - message_queue=message_queue, - description="component_service", - service_name="my_component_service", - ) - ``` - """ - - service_name: str - component: Any - - description: str = "Component service." - running: bool = True - step_interval: float = 0.1 - host: str - port: int - raise_exceptions: bool = False - - _message_queue: BaseMessageQueue = PrivateAttr() - _app: FastAPI = PrivateAttr() - _publisher_id: str = PrivateAttr() - _publish_callback: Optional[PublishCallback] = PrivateAttr() - _lock: asyncio.Lock = PrivateAttr() - _outstanding_calls: Dict[str, Any] = PrivateAttr() - - def __init__( - self, - component: Any, - message_queue: BaseMessageQueue, - running: bool = True, - description: str = "Component Server", - service_name: str = "default_component_service", - publish_callback: Optional[PublishCallback] = None, - step_interval: float = 0.1, - host: Optional[str] = None, - port: Optional[int] = None, - raise_exceptions: bool = False, - ) -> None: - # HACK: QueryComponent is on pydantic v1 - if not isinstance(component, QueryComponent): - raise ValueError("Component must be a QueryComponent") - - super().__init__( - component=component, - running=running, - description=description, - service_name=service_name, - step_interval=step_interval, - host=host, - port=port, - raise_exceptions=raise_exceptions, - ) - - self._lock = asyncio.Lock() - # self._tasks_as_tool_calls = {} - self._message_queue = message_queue - self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}" - self._publish_callback = publish_callback - - self._outstanding_calls: Dict[str, Any] = {} - - self._app = FastAPI(lifespan=self.lifespan) - - self._app.add_api_route( - "/", self.home, methods=["GET"], tags=["Component Service"] - ) - - self._app.add_api_route( - "/process_message", - self.process_message, - methods=["POST"], - tags=["Message Processing"], - ) - - @property - def service_definition(self) -> ServiceDefinition: - """Service definition.""" - return ServiceDefinition( - service_name=self.service_name, - description=self.description, - host=self.host, - port=self.port, - ) - - @property - def message_queue(self) -> BaseMessageQueue: - """Message queue.""" - return self._message_queue - - @property - def publisher_id(self) -> str: - """Publisher ID.""" - return self._publisher_id - - @property - def publish_callback(self) -> Optional[PublishCallback]: - 
"""Publish callback, if any.""" - return self._publish_callback - - @property - def lock(self) -> asyncio.Lock: - return self._lock - - async def processing_loop(self) -> None: - """The processing loop for the service.""" - logger.info("Processing initiated.") - while True: - if not self.running: - await asyncio.sleep(self.step_interval) - continue - - async with self.lock: - current_calls = [(t, c) for t, c in self._outstanding_calls.items()] - - for task_id, current_call in current_calls: - output_dict = await self.component.arun_component(**current_call) - - await self.message_queue.publish( - QueueMessage( - type=CONTROL_PLANE_MESSAGE_TYPE, - action=ActionTypes.COMPLETED_TASK, - data=TaskResult( - task_id=task_id, - history=[], - result=json.dumps(output_dict), - data=output_dict, - ).model_dump(), - ), - topic=self.get_topic(CONTROL_PLANE_MESSAGE_TYPE), - ) - - # clean up - async with self.lock: - del self._outstanding_calls[task_id] - - await asyncio.sleep(self.step_interval) - - async def process_message(self, message: QueueMessage) -> None: - """Process a message received from the message queue.""" - if message.action == ActionTypes.NEW_TASK: - task_def = TaskDefinition(**message.data or {}) - input_dict = json.loads(task_def.input) - async with self.lock: - self._outstanding_calls[task_def.task_id] = input_dict["__input_dict__"] - else: - raise ValueError(f"Unhandled action: {message.action}") - - def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer: - """Get the consumer for the message queue. - - Args: - remote (bool): - Whether the consumer is remote. Defaults to False. - If True, the consumer will be a RemoteMessageConsumer that uses the `process_message` endpoint. - """ - if remote: - url = ( - f"http://{self.host}:{self.port}{self._app.url_path_for('process_message')}" - if self.port - else f"http://{self.host}{self._app.url_path_for('process_message')}" - ) - return RemoteMessageConsumer( - id_=self.publisher_id, - url=url, - message_type=self.service_name, - ) - - return CallableMessageConsumer( - id_=self.publisher_id, - message_type=self.service_name, - handler=self.process_message, - ) - - async def launch_local(self) -> asyncio.Task: - """Launch the service in-process.""" - logger.info(f"{self.service_name} launch_local") - return asyncio.create_task(self.processing_loop()) - - # ---- Server based methods ---- - - @asynccontextmanager - async def lifespan(self, app: FastAPI) -> AsyncGenerator[None, None]: - """Starts the processing loop when the fastapi app starts.""" - asyncio.create_task(self.processing_loop()) - yield - self.running = False - - async def home(self) -> Dict[str, str]: - """Home endpoint. 
Returns general information about the service.""" - return { - "service_name": self.service_name, - "description": self.description, - "running": str(self.running), - "step_interval": str(self.step_interval), - "num_outstanding_calls": str(len(self._outstanding_calls)), - "type": "component_service", - } - - async def launch_server(self) -> None: - """Launch the service as a FastAPI server.""" - logger.info(f"Launching {self.service_name} server at {self.host}:{self.port}") - # uvicorn.run(self._app, host=self.host, port=self.port) - - class CustomServer(uvicorn.Server): - def install_signal_handlers(self) -> None: - pass - - cfg = uvicorn.Config(self._app, host=self.host, port=self.port) - server = CustomServer(cfg) - await server.serve() diff --git a/llama_deploy/services/tool.py b/llama_deploy/services/tool.py deleted file mode 100644 index 11825270..00000000 --- a/llama_deploy/services/tool.py +++ /dev/null @@ -1,335 +0,0 @@ -import asyncio -import uuid -from asyncio import Lock -from asyncio.exceptions import CancelledError -from contextlib import asynccontextmanager -from logging import getLogger -from typing import Any, AsyncGenerator, Dict, List, Optional - -import uvicorn -from fastapi import FastAPI -from llama_index.core.agent.function_calling.step import ( - get_function_by_name, -) -from llama_index.core.llms import MessageRole -from llama_index.core.tools import AsyncBaseTool, BaseTool, adapt_to_async_tool -from pydantic import PrivateAttr - -from llama_deploy.message_consumers.base import BaseMessageQueueConsumer -from llama_deploy.message_consumers.callable import CallableMessageConsumer -from llama_deploy.message_consumers.remote import RemoteMessageConsumer -from llama_deploy.message_publishers.publisher import PublishCallback -from llama_deploy.message_queues.base import BaseMessageQueue -from llama_deploy.messages.base import QueueMessage -from llama_deploy.services.base import BaseService -from llama_deploy.types import ( - ActionTypes, - ChatMessage, - ServiceDefinition, - ToolCall, - ToolCallResult, -) - -logger = getLogger(__name__) - - -class ToolService(BaseService): - """A service that executes tools remotely for other services. - - This service is responsible for executing tools remotely for other services and agents. - - Exposes the following endpoints: - - GET `/`: Home endpoint. - - POST `/tool_call`: Create a tool call. - - GET `/tool`: Get a tool by name. - - POST `/process_message`: Process a message. - - Attributes: - tools (List[AsyncBaseTool]): - A list of tools to execute. - description (str): - The description of the tool service. - running (bool): - Whether the service is running. - step_interval (float): - The interval in seconds to poll for tool call results. Defaults to 0.1s. - host (Optional[str]): - The host of the service. - port (Optional[int]): - The port of the service. 
- - Examples: - ```python - from llama_deploy import ToolService, MetaServiceTool, SimpleMessageQueue - from llama_index.core.llms import OpenAI - from llama_index.core.agent import FunctionCallingAgentWorker - - message_queue = SimpleMessageQueue() - - tool_service = ToolService( - message_queue=message_queue, - tools=[tool], - running=True, - step_interval=0.5, - ) - - # create a meta tool and use it in any other agent - # this allows remote execution of that tool - meta_tool = MetaServiceTool( - tool_metadata=tool.metadata, - message_queue=message_queue, - tool_service_name=tool_service.service_name, - ) - agent = FunctionCallingAgentWorker.from_tools( - [meta_tool], - llm=OpenAI(), - ).as_agent() - ``` - """ - - service_name: str - tools: List[AsyncBaseTool] - description: str = "Local Tool Service." - running: bool = True - step_interval: float = 0.1 - host: str - port: int - - _outstanding_tool_calls: Dict[str, ToolCall] = PrivateAttr() - _message_queue: BaseMessageQueue = PrivateAttr() - _app: FastAPI = PrivateAttr() - _publisher_id: str = PrivateAttr() - _publish_callback: Optional[PublishCallback] = PrivateAttr() - _lock: Lock = PrivateAttr() - - def __init__( - self, - message_queue: BaseMessageQueue, - tools: Optional[List[BaseTool]] = None, - running: bool = True, - description: str = "Tool Server", - service_name: str = "default_tool_service", - publish_callback: Optional[PublishCallback] = None, - step_interval: float = 0.1, - host: Optional[str] = None, - port: Optional[int] = None, - ) -> None: - tools = tools or [] - tools = [adapt_to_async_tool(t) for t in tools] - super().__init__( - tools=tools, - running=running, - description=description, - service_name=service_name, - step_interval=step_interval, - host=host, - port=port, - ) - - self._outstanding_tool_calls = {} - self._message_queue = message_queue - self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}" - self._publish_callback = publish_callback - self._lock = asyncio.Lock() - self._app = FastAPI(lifespan=self.lifespan) - - self._app.add_api_route("/", self.home, methods=["GET"], tags=["Tool Service"]) - - self._app.add_api_route( - "/tool_call", self.create_tool_call, methods=["POST"], tags=["Tool Call"] - ) - - self._app.add_api_route( - "/tool", self.get_tool_by_name, methods=["GET"], tags=["Tool"] - ) - - self._app.add_api_route( - "/process_message", - self.process_message, - methods=["POST"], - tags=["Message Processing"], - ) - - @property - def service_definition(self) -> ServiceDefinition: - """The service definition.""" - return ServiceDefinition( - service_name=self.service_name, - description=self.description, - prompt=[], - host=self.host, - port=self.port, - ) - - @property - def message_queue(self) -> BaseMessageQueue: - """The message queue.""" - return self._message_queue - - @property - def publisher_id(self) -> str: - """The publisher ID.""" - return self._publisher_id - - @property - def publish_callback(self) -> Optional[PublishCallback]: - """The publish callback, if any.""" - return self._publish_callback - - @property - def lock(self) -> Lock: - return self._lock - - async def processing_loop(self) -> None: - """The processing loop for the service.""" - logger.info("Processing initiated.") - try: - await self._processing_loop() - except CancelledError: - logger.info("Processing loop cancelled...") - - async def _processing_loop(self) -> None: - while True: - if not self.running: - await asyncio.sleep(self.step_interval) - continue - - async with self.lock: - 
current_tool_calls: List[ToolCall] = [ - *self._outstanding_tool_calls.values() - ] - for tool_call in current_tool_calls: - tool = get_function_by_name( - self.tools, tool_call.tool_call_bundle.tool_name - ) - if tool is None: - continue - - logger.info( - f"Processing tool call id {tool_call.id_} with {tool.metadata.name}" - ) - tool_output = tool( - *tool_call.tool_call_bundle.tool_args, - **tool_call.tool_call_bundle.tool_kwargs, - ) - - # execute function call - tool_message = ChatMessage( - content=str(tool_output), - role=MessageRole.TOOL, - additional_kwargs={ - "name": tool_call.tool_call_bundle.tool_name, - "tool_call_id": tool_call.id_, - }, - ) - - # publish the completed task - await self.publish( - QueueMessage( - type=tool_call.source_id, - action=ActionTypes.COMPLETED_TOOL_CALL, - data=ToolCallResult( - id_=tool_call.id_, - tool_message=tool_message, - result=str(tool_output), - ).model_dump(), - ) - ) - - # clean up - async with self.lock: - del self._outstanding_tool_calls[tool_call.id_] - - await asyncio.sleep(self.step_interval) - - async def process_message(self, message: QueueMessage) -> None: - """Process a message.""" - if message.action == ActionTypes.NEW_TOOL_CALL: - tool_call_data = {"source_id": message.publisher_id} - tool_call_data.update(message.data or {}) - tool_call = ToolCall( - **tool_call_data # type: ignore - ) # FIXME: field `tool_bundle` is missing and not optional - async with self.lock: - self._outstanding_tool_calls.update({tool_call.id_: tool_call}) - else: - raise ValueError(f"Unhandled action: {message.action}") - - def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer: - """Get the consumer for the service. - - Args: - remote (bool): - Whether the consumer is remote. Defaults to False. - If True, the consumer will be a RemoteMessageConsumer that uses the `process_message` endpoint. - """ - if remote: - url = ( - f"http://{self.host}:{self.port}{self._app.url_path_for('process_message')}" - if self.port - else f"http://{self.host}{self._app.url_path_for('process_message')}" - ) - return RemoteMessageConsumer( - id_=self.publisher_id, - url=url, - message_type=self.service_name, - ) - return CallableMessageConsumer( - id_=self.publisher_id, - message_type=self.service_name, - handler=self.process_message, - ) - - async def launch_local(self) -> asyncio.Task: - """Launch the service in-process.""" - return asyncio.create_task(self.processing_loop()) - - # ---- Server based methods ---- - - @asynccontextmanager - async def lifespan(self, app: FastAPI) -> AsyncGenerator[None, None]: - """Starts the processing loop when the fastapi app starts.""" - asyncio.create_task(self.processing_loop()) - yield - self.running = False - - async def home(self) -> Dict[str, str]: - """Home endpoint. 
Returns the general information about the service.""" - return { - "service_name": self.service_name, - "description": self.description, - "running": str(self.running), - "step_interval": str(self.step_interval), - "num_tools": str(len(self.tools)), - "num_outstanding_tool_calls": str(len(self._outstanding_tool_calls)), - "tool_calls": "\n".join( - [str(tool_call) for tool_call in self._outstanding_tool_calls.values()] - ), - "type": "tool_service", - } - - async def create_tool_call(self, tool_call: ToolCall) -> Dict[str, str]: - """Create a tool call.""" - async with self.lock: - self._outstanding_tool_calls.update({tool_call.id_: tool_call}) - return {"tool_call_id": tool_call.id_} - - async def get_tool_by_name(self, name: str) -> Dict[str, Any]: - """Get a tool by name.""" - name_to_tool = {tool.metadata.name: tool for tool in self.tools} - if name not in name_to_tool: - raise ValueError(f"Tool with name {name} not found") - return {"tool_metadata": name_to_tool[name].metadata} - - async def launch_server(self) -> None: - """Launch the service as a FastAPI server.""" - logger.info(f"Launching tool service server at {self.host}:{self.port}") - # uvicorn.run(self._app, host=self.host, port=self.port) - - class CustomServer(uvicorn.Server): - def install_signal_handlers(self) -> None: - pass - - cfg = uvicorn.Config(self._app, host=self.host, port=self.port) - server = CustomServer(cfg) - await server.serve() diff --git a/llama_deploy/services/types.py b/llama_deploy/services/types.py deleted file mode 100644 index fa6ae347..00000000 --- a/llama_deploy/services/types.py +++ /dev/null @@ -1,120 +0,0 @@ -from pydantic import BaseModel -from typing import Dict, List, Optional, cast - -from llama_index.core.agent.types import TaskStep, TaskStepOutput, Task -from llama_index.core.agent.runner.base import AgentState, TaskState - -from llama_deploy.types import ChatMessage - -# ------ FastAPI types ------ - - -class _Task(BaseModel): - task_id: str - input: Optional[str] - extra_state: dict - - @classmethod - def from_task(cls, task: Task) -> "_Task": - _extra_state = {} - for key, value in task.extra_state.items(): - _extra_state[key] = str(value) - - return cls(task_id=task.task_id, input=task.input, extra_state=_extra_state) - - -class _TaskStep(BaseModel): - task_id: str - step_id: str - input: Optional[str] - step_state: dict - prev_steps: List["_TaskStep"] - next_steps: List["_TaskStep"] - is_ready: bool - - @classmethod - def from_task_step(cls, task_step: TaskStep) -> "_TaskStep": - _step_state = {} - for key, value in task_step.step_state.items(): - _step_state[key] = str(value) - - return cls( - task_id=task_step.task_id, - step_id=task_step.step_id, - input=task_step.input, - step_state=_step_state, - prev_steps=[ - cls.from_task_step(cast(TaskStep, prev_step)) - for prev_step in task_step.prev_steps - ], - next_steps=[ - cls.from_task_step(cast(TaskStep, next_step)) - for next_step in task_step.next_steps - ], - is_ready=task_step.is_ready, - ) - - -class _TaskStepOutput(BaseModel): - output: str - task_step: _TaskStep - next_steps: List[_TaskStep] - is_last: bool - - @classmethod - def from_task_step_output(cls, step_output: TaskStepOutput) -> "_TaskStepOutput": - return cls( - output=str(step_output.output), - task_step=_TaskStep.from_task_step(step_output.task_step), - next_steps=[ - _TaskStep.from_task_step(next_step) - for next_step in step_output.next_steps - ], - is_last=step_output.is_last, - ) - - -class _TaskSate(BaseModel): - task: _Task - step_queue: List[_TaskStep] - 
completed_steps: List[_TaskStepOutput] - - @classmethod - def from_task_state(cls, task_state: TaskState) -> "_TaskSate": - return cls( - task=_Task.from_task(task_state.task), - step_queue=[ - _TaskStep.from_task_step(step) for step in list(task_state.step_queue) - ], - completed_steps=[ - _TaskStepOutput.from_task_step_output(step) - for step in task_state.completed_steps - ], - ) - - -class _AgentState(BaseModel): - task_dict: Dict[str, _TaskSate] - - @classmethod - def from_agent_state(cls, agent_state: AgentState) -> "_AgentState": - return cls( - task_dict={ - task_id: _TaskSate.from_task_state(task_state) - for task_id, task_state in agent_state.task_dict.items() - } - ) - - -class _ChatMessage(BaseModel): - content: str - role: str - additional_kwargs: dict - - @classmethod - def from_chat_message(cls, chat_message: ChatMessage) -> "_ChatMessage": - return cls( - content=str(chat_message.content), - role=str(chat_message.role), - additional_kwargs=chat_message.additional_kwargs, - ) diff --git a/tests/services/test_tool_service.py b/tests/services/test_tool_service.py deleted file mode 100644 index b2672226..00000000 --- a/tests/services/test_tool_service.py +++ /dev/null @@ -1,181 +0,0 @@ -import asyncio -from typing import Any, List - -import pytest -from llama_index.core.tools import BaseTool, FunctionTool -from pydantic import PrivateAttr - -from llama_deploy.message_consumers.base import BaseMessageQueueConsumer -from llama_deploy.message_queues.simple import SimpleMessageQueue -from llama_deploy.messages.base import QueueMessage -from llama_deploy.services import ToolService -from llama_deploy.types import ActionTypes, ToolCall, ToolCallBundle - -TOOL_CALL_SOURCE = "mock-source" - - -class MockMessageConsumer(BaseMessageQueueConsumer): - processed_messages: List[QueueMessage] = [] - _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) - - async def _process_message(self, message: QueueMessage, **kwargs: Any) -> None: - async with self._lock: - self.processed_messages.append(message) - - -@pytest.fixture() -def tools() -> List[BaseTool]: - def multiply(a: int, b: int) -> int: - """Multiple two integers and returns the result integer""" - return a * b - - return [FunctionTool.from_defaults(fn=multiply)] - - -@pytest.fixture() -def tool_call() -> ToolCall: - tool_bundle = ToolCallBundle( - tool_name="multiply", tool_args=[], tool_kwargs={"a": 1, "b": 2} - ) - return ToolCall(tool_call_bundle=tool_bundle, source_id=TOOL_CALL_SOURCE) - - -@pytest.fixture() -def tool_output_consumer() -> MockMessageConsumer: - return MockMessageConsumer(message_type=TOOL_CALL_SOURCE) - - -@pytest.mark.asyncio() -async def test_init(tools: List[BaseTool]) -> None: - # arrange - server = ToolService( - SimpleMessageQueue(), # type: ignore - tools=tools, - running=False, - description="Test Tool Server", - step_interval=0.5, - host="localhost", - port=8001, - ) - - # act - result = await server.get_tool_by_name("multiply") - multiply_tool_metadata = result["tool_metadata"] - - # assert - assert server.tools == tools - assert multiply_tool_metadata == tools[0].metadata - assert server.running is False - assert server.description == "Test Tool Server" - assert server.step_interval == 0.5 - - -@pytest.mark.asyncio() -async def test_create_tool_call(tools: List[BaseTool], tool_call: ToolCall) -> None: - # arrange - server = ToolService( - SimpleMessageQueue(), # type:ignore - tools=tools, - running=False, - description="Test Tool Server", - step_interval=0.5, - host="localhost", - port=8001, - 
) - - # act - result = await server.create_tool_call(tool_call) - - # assert - assert result == {"tool_call_id": tool_call.id_} - assert server._outstanding_tool_calls[tool_call.id_] == tool_call - - -@pytest.mark.asyncio() -async def test_process_tool_call( - tools: List[BaseTool], - tool_call: ToolCall, - tool_output_consumer: MockMessageConsumer, - message_queue_server: Any, -) -> None: - # arrange - mq = SimpleMessageQueue() - server = ToolService( - mq, # type:ignore - tools=tools, - running=True, - description="Test Tool Server", - step_interval=0.5, - host="localhost", - port=8001, - ) - consumer_fn = await mq.register_consumer( - tool_output_consumer, topic="llama_deploy.mock-source" - ) - consumer_task = asyncio.create_task(consumer_fn()) - server_task = await server.launch_local() - - # act - result = await server.create_tool_call(tool_call) - - # Give some time for last message to get published and sent to consumers - await asyncio.sleep(1) - consumer_task.cancel() - server_task.cancel() - await asyncio.gather(consumer_task, server_task) - - # assert - assert server.message_queue == mq - assert result == {"tool_call_id": tool_call.id_} - assert len(tool_output_consumer.processed_messages) == 1 - assert tool_output_consumer.processed_messages[0].data.get("result") == "2" - - -@pytest.mark.asyncio() -async def test_process_tool_call_from_queue( - tools: List[BaseTool], - tool_call: ToolCall, - tool_output_consumer: MockMessageConsumer, - message_queue_server: Any, -) -> None: - # arrange - mq = SimpleMessageQueue() - server = ToolService( - mq, # type:ignore - tools=tools, - running=True, - service_name="test_tool_service", - description="Test Tool Server", - step_interval=0.5, - host="localhost", - port=8001, - ) - consumer_fn = await mq.register_consumer( - tool_output_consumer, topic="llama_deploy.mock-source" - ) - consumer_task = asyncio.create_task(consumer_fn()) - service_consumer_fn = await mq.register_consumer( - server.as_consumer(), topic="test_tool_service" - ) - service_consumer_task = asyncio.create_task(service_consumer_fn()) - server_task = await server.launch_local() - - # act - tool_call_message = QueueMessage( - data=tool_call.model_dump(), - action=ActionTypes.NEW_TOOL_CALL, - type="test_tool_service", - ) - await mq.publish(tool_call_message, topic="test_tool_service") - - # Give some time for last message to get published and sent to consumers - await asyncio.sleep(1) - consumer_task.cancel() - service_consumer_task.cancel() - server_task.cancel() - await asyncio.gather(consumer_task, service_consumer_task, server_task) - - # assert - assert server.message_queue == mq - assert len(tool_output_consumer.processed_messages) == 1 - assert tool_output_consumer.processed_messages[0].data.get("result") == "2" From 5175d8f847f4653b62ccc1374c89f2bc4fba441a Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Mon, 23 Dec 2024 09:44:47 +0100 Subject: [PATCH 3/4] remove unused package --- llama_deploy/sdk/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 llama_deploy/sdk/__init__.py diff --git a/llama_deploy/sdk/__init__.py b/llama_deploy/sdk/__init__.py deleted file mode 100644 index e69de29b..00000000 From c14a2f77915dfb8b4f3501f2677334aaee65e4d7 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Mon, 23 Dec 2024 09:59:47 +0100 Subject: [PATCH 4/4] remove more unused modules --- llama_deploy/types/__init__.py | 2 - llama_deploy/types/core.py | 98 ++++++++----------------------- llama_deploy/utils.py | 7 --- 
 tests/orchestrators/test_utils.py |  9 +++
 4 files changed, 32 insertions(+), 84 deletions(-)
 delete mode 100644 llama_deploy/utils.py
 create mode 100644 tests/orchestrators/test_utils.py

diff --git a/llama_deploy/types/__init__.py b/llama_deploy/types/__init__.py
index d959bbf4..105244ab 100644
--- a/llama_deploy/types/__init__.py
+++ b/llama_deploy/types/__init__.py
@@ -5,7 +5,6 @@
     ChatMessage,
     EventDefinition,
     HumanResponse,
-    MessageRole,
     PydanticValidatedUrl,
     ServiceDefinition,
     SessionDefinition,
@@ -24,7 +23,6 @@
     "ChatMessage",
     "EventDefinition",
     "HumanResponse",
-    "MessageRole",
     "PydanticValidatedUrl",
     "ServiceDefinition",
     "SessionDefinition",
diff --git a/llama_deploy/types/core.py b/llama_deploy/types/core.py
index 2ae4923c..f6f6a26e 100644
--- a/llama_deploy/types/core.py
+++ b/llama_deploy/types/core.py
@@ -1,11 +1,10 @@
 import uuid
 from enum import Enum
-from pydantic import BaseModel, Field, BeforeValidator, HttpUrl, TypeAdapter
-from pydantic.v1 import BaseModel as V1BaseModel
-from typing import Any, Dict, List, Optional, Union
-from typing_extensions import Annotated
+from typing import Any
 
-from llama_index.core.llms import MessageRole
+from llama_index.core.llms import ChatMessage
+from pydantic import BaseModel, BeforeValidator, Field, HttpUrl, TypeAdapter
+from typing_extensions import Annotated
 
 
 def generate_id() -> str:
@@ -15,57 +14,6 @@ def generate_id() -> str:
 CONTROL_PLANE_NAME = "control_plane"
 
 
-class ChatMessage(BaseModel):
-    """Chat message.
-
-    TODO: Temp copy of class from llama-index, to avoid pydantic v1/v2 issues.
-    """
-
-    role: MessageRole = MessageRole.USER
-    content: Optional[Any] = ""
-    additional_kwargs: dict = Field(default_factory=dict)
-
-    def __str__(self) -> str:
-        return f"{self.role.value}: {self.content}"
-
-    @classmethod
-    def from_str(
-        cls,
-        content: str,
-        role: Union[MessageRole, str] = MessageRole.USER,
-        **kwargs: Any,
-    ) -> "ChatMessage":
-        if isinstance(role, str):
-            role = MessageRole(role)
-        return cls(role=role, content=content, **kwargs)
-
-    def _recursive_serialization(self, value: Any) -> Any:
-        if isinstance(value, (V1BaseModel, BaseModel)):
-            return value.dict()
-        if isinstance(value, dict):
-            return {
-                key: self._recursive_serialization(value)
-                for key, value in value.items()
-            }
-        if isinstance(value, list):
-            return [self._recursive_serialization(item) for item in value]
-        return value
-
-    def dict(self, **kwargs: Any) -> dict:
-        # ensure all additional_kwargs are serializable
-        msg = super().dict(**kwargs)
-
-        for key, value in msg.get("additional_kwargs", {}).items():
-            value = self._recursive_serialization(value)
-            if not isinstance(value, (str, int, float, bool, dict, list, type(None))):
-                raise ValueError(
-                    f"Failed to serialize additional_kwargs value: {value}"
-                )
-            msg["additional_kwargs"][key] = value
-
-        return msg
-
-
 class ActionTypes(str, Enum):
     """
     Action types for messages.
@@ -99,8 +47,8 @@ class TaskDefinition(BaseModel):
 
     input: str
     task_id: str = Field(default_factory=generate_id)
-    session_id: Optional[str] = None
-    agent_id: Optional[str] = None
+    session_id: str | None = None
+    agent_id: str | None = None
 
 
 class SessionDefinition(BaseModel):
@@ -110,18 +58,18 @@ class SessionDefinition(BaseModel):
 
     Attributes:
         session_id (str):
            The session ID. Defaults to a random UUID.
-        task_definitions (List[str]):
+        task_definitions (list[str]):
            The task ids in order, representing the session.
        state (dict):
            The current session state.
""" session_id: str = Field(default_factory=generate_id) - task_ids: List[str] = Field(default_factory=list) + task_ids: list[str] = Field(default_factory=list) state: dict = Field(default_factory=dict) @property - def current_task_id(self) -> Optional[str]: + def current_task_id(self) -> str | None: if len(self.task_ids) == 0: return None @@ -149,7 +97,7 @@ class TaskResult(BaseModel): Attributes: task_id (str): The task ID. - history (List[ChatMessage]): + history (list[ChatMessage]): The task history. result (str): The task result. @@ -162,7 +110,7 @@ class TaskResult(BaseModel): """ task_id: str - history: List[ChatMessage] + history: list[ChatMessage] result: str data: dict = Field(default_factory=dict) @@ -174,14 +122,14 @@ class TaskStream(BaseModel): Attributes: task_id (str): The associated task ID. - data (List[dict]): + data (list[dict]): The stream data. index (int): The index of the stream data. """ task_id: str - session_id: Optional[str] + session_id: str | None data: dict index: int @@ -193,15 +141,15 @@ class ToolCallBundle(BaseModel): Attributes: tool_name (str): The name of the tool. - tool_args (List[Any]): + tool_args (list[Any]): The tool arguments. - tool_kwargs (Dict[str, Any]): + tool_kwargs (dict[str, Any]): The tool keyword arguments """ tool_name: str - tool_args: List[Any] - tool_kwargs: Dict[str, Any] + tool_args: list[Any] + tool_kwargs: dict[str, Any] class ToolCall(BaseModel): @@ -249,11 +197,11 @@ class ServiceDefinition(BaseModel): The name of the service. description (str): A description of the service and it's purpose. - prompt (List[ChatMessage]): + prompt (list[ChatMessage]): Specific instructions for the service. - host (Optional[str]): + host (str | None): The host of the service, if its a network service. - port (Optional[int]): + port (int | None): The port of the service, if its a network service. """ @@ -261,11 +209,11 @@ class ServiceDefinition(BaseModel): description: str = Field( description="A description of the service and it's purpose." ) - prompt: List[ChatMessage] = Field( + prompt: list[ChatMessage] = Field( default_factory=list, description="Specific instructions for the service." ) - host: Optional[str] = None - port: Optional[int] = None + host: str | None = None + port: int | None = None class HumanResponse(BaseModel): diff --git a/llama_deploy/utils.py b/llama_deploy/utils.py deleted file mode 100644 index b804e232..00000000 --- a/llama_deploy/utils.py +++ /dev/null @@ -1,7 +0,0 @@ -from string import Formatter -from typing import List - - -def get_prompt_params(prompt_template_str: str) -> List[str]: - """Get the list of prompt params from the template format string.""" - return [param for _, param, _, _ in Formatter().parse(prompt_template_str) if param] diff --git a/tests/orchestrators/test_utils.py b/tests/orchestrators/test_utils.py new file mode 100644 index 00000000..efacb326 --- /dev/null +++ b/tests/orchestrators/test_utils.py @@ -0,0 +1,9 @@ +from llama_deploy.orchestrators.utils import get_result_key, get_stream_key + + +def test_get_result_key() -> None: + assert get_result_key("test_task") == "result_test_task" + + +def test_get_stream_key() -> None: + assert get_stream_key("test_task") == "stream_test_task"