Skip to content

Commit

Permalink
fix mixin (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai authored Jun 10, 2024
1 parent 0dc0d27 commit 91e6ef2
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 19 deletions.
16 changes: 13 additions & 3 deletions agentfile/agent_server/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import uuid
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
Expand Down Expand Up @@ -57,10 +58,11 @@ def __init__(
agent_id=agent_id, description=description
)
self.agent = agent
self.message_queue = message_queue
self._message_queue = message_queue
self.description = description
self.running = running
self.step_interval = step_interval
self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}"

self.app = FastAPI(lifespan=self.lifespan)

Expand Down Expand Up @@ -126,6 +128,14 @@ def __init__(
def agent_definition(self) -> AgentDefinition:
return self._agent_definition

@property
def message_queue(self) -> BaseMessageQueue:
return self._message_queue

@property
def publisher_id(self) -> str:
return self._publisher_id

@asynccontextmanager
async def lifespan(self, app: FastAPI) -> AsyncGenerator[None, None]:
logger.info("Starting up")
Expand Down Expand Up @@ -175,9 +185,9 @@ async def start_processing_loop(self) -> None:
response = self.agent.finalize_response(
task_id, step_output=step_output
)
await self.message_queue.publish(
await self.publish(
QueueMessage(
source_id=self.id_,
source_id=self.publisher_id,
type="control_plane",
action=ActionTypes.COMPLETED_TASK,
data=TaskResult(
Expand Down
20 changes: 15 additions & 5 deletions agentfile/control_plane/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
import uvicorn
from fastapi import FastAPI
from typing import Any, Callable, Dict, Optional
Expand Down Expand Up @@ -61,7 +62,8 @@ def __init__(
self.active_flows_store_key = active_flows_store_key
self.tasks_store_key = tasks_store_key

self.message_queue = message_queue
self._message_queue = message_queue
self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}"

self.app = FastAPI()
self.app.add_api_route("/", self.home, methods=["GET"], tags=["Control Plane"])
Expand Down Expand Up @@ -90,6 +92,14 @@ def __init__(
"/tasks/{task_id}", self.get_task_state, methods=["GET"], tags=["Tasks"]
)

@property
def message_queue(self) -> BaseMessageQueue:
return self._message_queue

@property
def publisher_id(self) -> str:
return self._publisher_id

def get_consumer(self) -> BaseMessageQueueConsumer:
return ControlPlaneMessageConsumer(
message_handler={
Expand Down Expand Up @@ -155,9 +165,9 @@ async def send_task_to_agent(self, task_def: TaskDefinition) -> None:
all_agents = await self.state_store.aget_all(collection=self.agents_store_key)
agent_id = task_def.agent_id or list(all_agents.keys())[0]

await self.message_queue.publish(
await self.publish(
QueueMessage(
source_id=self.id_,
source_id=self.publisher_id,
type=agent_id,
data=task_def.model_dump(),
action=ActionTypes.NEW_TASK,
Expand All @@ -174,9 +184,9 @@ async def handle_agent_completion(
task_result.task_id, collection=self.tasks_store_key
)

await self.message_queue.publish(
await self.publish(
QueueMessage(
source_id=self.id_,
source_id=self.publisher_id,
type="human",
action=ActionTypes.COMPLETED_TASK,
data=task_result.result,
Expand Down
16 changes: 13 additions & 3 deletions agentfile/launchers/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import uuid
from typing import Any, Callable, Dict, List, Optional

from agentfile.agent_server.base import BaseAgentServer
Expand Down Expand Up @@ -32,7 +33,16 @@ def __init__(
) -> None:
self.agent_servers = agent_servers
self.control_plane = control_plane
self.message_queue = message_queue
self._message_queue = message_queue
self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}"

@property
def message_queue(self) -> SimpleMessageQueue:
return self._message_queue

@property
def publisher_id(self) -> str:
return self._publisher_id

async def handle_human_message(self, **kwargs: Any) -> None:
print("Got response:\n", str(kwargs), flush=True)
Expand Down Expand Up @@ -63,9 +73,9 @@ async def alaunch_single(self, initial_task: str) -> None:
await self.register_consumers([human_consumer])

# publish initial task
await self.message_queue.publish(
await self.publish(
QueueMessage(
source_id=self.id_,
source_id=self.publisher_id,
type="control_plane",
action=ActionTypes.NEW_TASK,
data=TaskDefinition(input=initial_task).model_dump(),
Expand Down
22 changes: 16 additions & 6 deletions agentfile/message_publishers/publisher.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import uuid
from abc import ABC
from abc import ABC, abstractmethod
from typing import Any
from agentfile.messages.base import QueueMessage
from agentfile.message_queues.base import BaseMessageQueue


class MessageQueuePublisherMixin(ABC):
"""PublisherMixing."""

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
if not hasattr(self, "id_"):
self.id_ = f"{self.__class__.__qualname__}-{uuid.uuid4()}"
@property
@abstractmethod
def publisher_id(self) -> str:
...

@property
@abstractmethod
def message_queue(self) -> BaseMessageQueue:
...

async def publish(self, message: QueueMessage, **kwargs: Any) -> Any:
"""Publish message."""
return await self.message_queue.publish(message, **kwargs)
11 changes: 9 additions & 2 deletions tests/test_agent_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@
from agentfile.agent_server import FastAPIAgentServer
from agentfile.message_queues.simple import SimpleMessageQueue

from unittest.mock import patch, MagicMock

def test_init() -> None:

@patch("agentfile.agent_server.fastapi.uuid")
def test_init(mock_uuid: MagicMock) -> None:
mock_uuid.uuid4.return_value = "mock"
agent = ReActAgent.from_tools([], llm=MockLLM())
mq = SimpleMessageQueue()
server = FastAPIAgentServer(
agent,
SimpleMessageQueue(),
mq,
running=False,
description="Test Agent Server",
step_interval=0.5,
Expand All @@ -19,3 +24,5 @@ def test_init() -> None:
assert server.running is False
assert server.description == "Test Agent Server"
assert server.step_interval == 0.5
assert server.message_queue == mq
assert server.publisher_id == "FastAPIAgentServer-mock"

0 comments on commit 91e6ef2

Please sign in to comment.