Skip to content

Commit

Permalink
MessageQueue items (#1)
Browse files Browse the repository at this point in the history
* scaffold

* fmt

* wip

* add simple

* test for simple mq

* add different consumer type to test

* deregister test

* add publish test

* infinite loop for passing messages

* lock to privateattr

* minor fixes

---------

Co-authored-by: Logan Markewich <[email protected]>
  • Loading branch information
nerdai and logan-markewich authored Jun 6, 2024
1 parent 68e19ee commit a562d7c
Show file tree
Hide file tree
Showing 15 changed files with 332 additions and 2 deletions.
12 changes: 12 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"[python]": {
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.fixAll": "explicit"
},
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.testing.pytestArgs": ["tests"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
2 changes: 1 addition & 1 deletion agentfile/agent_server/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,6 @@ class AgentRole(BaseModel):
default_factory=list, description="Specific instructions for the agent."
)
agent_id: str = Field(
default_factory=str(uuid.uuid4()),
default_factory=lambda: str(uuid.uuid4()),
description="A unique identifier for the agent.",
)
Empty file.
36 changes: 36 additions & 0 deletions agentfile/message_consumers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Message consumers."""

import uuid
from abc import ABC, abstractmethod
from typing import Any, Type, TYPE_CHECKING
from llama_index.core.bridge.pydantic import BaseModel, Field
from agentfile.messages.base import BaseMessage

if TYPE_CHECKING:
from agentfile.message_queues.base import BaseMessageQueue


class BaseMessageQueueConsumer(BaseModel, ABC):
"""Consumer of a MessageQueue."""

id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
message_type: Type[BaseMessage] = Field(default=BaseMessage)

class Config:
arbitrary_types_allowed = True

@abstractmethod
async def _process_message(self, message: BaseMessage, **kwargs: Any) -> Any:
"""Subclasses should implement logic here."""

async def process_message(self, message: BaseMessage, **kwargs: Any) -> Any:
"""Logic for processing message."""
if not isinstance(message, self.message_type):
raise ValueError("Consumer cannot process the given kind of Message.")
return await self._process_message(message, **kwargs)

async def start_consuming(
self, message_queue: "BaseMessageQueue", **kwargs: Any
) -> None:
"""Begin consuming messages."""
await message_queue.register_consumer(self, **kwargs)
Empty file.
50 changes: 50 additions & 0 deletions agentfile/message_queues/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Message queue module."""

from abc import ABC, abstractmethod
from typing import Any, List, Protocol, Type, TYPE_CHECKING
from llama_index.core.bridge.pydantic import BaseModel
from agentfile.messages.base import BaseMessage

if TYPE_CHECKING:
from agentfile.message_consumers.base import BaseMessageQueueConsumer


class MessageProcessor(Protocol):
"""Protocol for a callable that processes messages."""

def __call__(self, message: BaseMessage, **kwargs: Any) -> None:
...


class BaseMessageQueue(BaseModel, ABC):
"""Message broker interface between publisher and consumer."""

class Config:
arbitrary_types_allowed = True

@abstractmethod
async def _publish(self, message: BaseMessage, **kwargs: Any) -> Any:
"""Subclasses implement publish logic here."""
...

async def publish(self, message: BaseMessage, **kwargs: Any) -> Any:
"""Send message to a consumer."""
await self._publish(message, **kwargs)

@abstractmethod
async def register_consumer(
self, consumer: "BaseMessageQueueConsumer", **kwargs: Any
) -> Any:
"""Register consumer to start consuming messages."""

@abstractmethod
async def deregister_consumer(self, consumer: "BaseMessageQueueConsumer") -> Any:
"""Deregister consumer to stop publishing messages)."""

async def get_consumers(
self, message_type: Type[BaseMessage]
) -> List["BaseMessageQueueConsumer"]:
"""Gets list of consumers according to a message type."""
raise NotImplementedError(
"`get_consumers()` is not implemented for this class."
)
103 changes: 103 additions & 0 deletions agentfile/message_queues/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Simple Message Queue."""

import asyncio
import random

from collections import deque
from typing import Any, Dict, List, Type
from llama_index.core.bridge.pydantic import Field
from agentfile.message_queues.base import BaseMessageQueue
from agentfile.messages.base import BaseMessage
from agentfile.message_consumers.base import BaseMessageQueueConsumer


class SimpleMessageQueue(BaseMessageQueue):
"""SimpleMessageQueue.
An in-memory message queue that implements a push model for consumers.
"""

consumers: Dict[str, Dict[str, BaseMessageQueueConsumer]] = Field(
default_factory=dict
)
queues: Dict[str, deque] = Field(default_factory=dict)
running: bool = True

def __init__(
self,
consumers: Dict[str, Dict[str, BaseMessageQueueConsumer]] = {},
queues: Dict[str, deque] = {},
):
super().__init__(consumers=consumers, queues=queues)

def _select_consumer(self, message: BaseMessage) -> BaseMessageQueueConsumer:
"""Select a single consumer to publish a message to."""
message_type_str = message.class_name()
consumer_id = random.choice(list(self.consumers[message_type_str].keys()))
return self.consumers[message_type_str][consumer_id]

async def _publish(self, message: BaseMessage, **kwargs: Any) -> Any:
"""Publish message to a queue."""
message_type_str = message.class_name()

if message_type_str not in self.consumers:
raise ValueError(f"No consumer for {message_type_str} has been registered.")

if message_type_str not in self.queues:
self.queues[message_type_str] = deque()

self.queues[message_type_str].append(message)

async def _publish_to_consumer(self, message: BaseMessage, **kwargs: Any) -> Any:
"""Publish message to a consumer."""
consumer = self._select_consumer(message)
try:
await consumer.process_message(message, **kwargs)
except Exception:
raise

async def start(self) -> None:
"""A loop for getting messages from queues and sending to consumer."""
while self.running:
print(self.queues)
for queue in self.queues.values():
if queue:
message = queue.popleft()
await self._publish_to_consumer(message)
print(self.queues)
await asyncio.sleep(0.1)

async def register_consumer(
self, consumer: BaseMessageQueueConsumer, **kwargs: Any
) -> None:
"""Register a new consumer."""
message_type_str = consumer.message_type.class_name()

if message_type_str not in self.consumers:
self.consumers[message_type_str] = {consumer.id_: consumer}
else:
if consumer.id_ in self.consumers[message_type_str]:
raise ValueError("Consumer has already been added.")

self.consumers[message_type_str][consumer.id_] = consumer

if message_type_str not in self.queues:
self.queues[message_type_str] = deque()

async def deregister_consumer(self, consumer: BaseMessageQueueConsumer) -> None:
message_type_str = consumer.message_type.class_name()
if consumer.id_ not in self.consumers[message_type_str]:
raise ValueError("No consumer found for associated message type.")

del self.consumers[message_type_str][consumer.id_]
if len(self.consumers[message_type_str]) == 0:
del self.consumers[message_type_str]

async def get_consumers(
self, message_type: Type[BaseMessage]
) -> List[BaseMessageQueueConsumer]:
message_type_str = message_type.class_name()
if message_type_str not in self.consumers:
return []

return list(self.consumers[message_type_str].values())
Empty file added agentfile/messages/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions agentfile/messages/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Base Message."""

import uuid
from typing import Any, Optional
from llama_index.core.bridge.pydantic import BaseModel, Field


class BaseMessage(BaseModel):
id_: str = Field(default_factory=lambda: str(uuid.uuid4))
data: Optional[Any] = Field(default_factory=None)

@classmethod
def class_name(cls) -> str:
"""Class name."""
return "BaseMessage"

class Config:
arbitrary_types_allowed = True
20 changes: 19 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ fastapi = "^0.111.0"
llama-index-core = "^0.10.40"
llama-index-agent-openai = "^0.2.5"
llama-index-embeddings-openai = "^0.1.10"
pytest-asyncio = "^0.23.7"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.2"
Expand Down
Binary file added tests/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Empty file.
92 changes: 92 additions & 0 deletions tests/message_queues/test_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import asyncio
import pytest
from typing import Any, List
from llama_index.core.bridge.pydantic import PrivateAttr
from agentfile.message_consumers.base import BaseMessageQueueConsumer
from agentfile.message_queues.simple import SimpleMessageQueue
from agentfile.messages.base import BaseMessage


class MockMessageConsumer(BaseMessageQueueConsumer):
processed_messages: List[BaseMessage] = []
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)

async def _process_message(self, message: BaseMessage, **kwargs: Any) -> None:
async with self._lock:
self.processed_messages.append(message)


class MockMessage(BaseMessage):
@classmethod
def class_name(cls) -> str:
return "MockMessage"


@pytest.mark.asyncio()
async def test_simple_register_consumer() -> None:
# Arrange
consumer_one = MockMessageConsumer()
consumer_two = MockMessageConsumer(message_type=MockMessage)
mq = SimpleMessageQueue()

# Act
await mq.register_consumer(consumer_one)
await mq.register_consumer(consumer_two)
with pytest.raises(ValueError):
await mq.register_consumer(consumer_two)

# Assert
assert consumer_one.id_ in [
c.id_ for c in await mq.get_consumers(consumer_one.message_type)
]
assert consumer_two.id_ in [
c.id_ for c in await mq.get_consumers(consumer_two.message_type)
]


@pytest.mark.asyncio()
async def test_simple_deregister_consumer() -> None:
# Arrange
consumer_one = MockMessageConsumer()
consumer_two = MockMessageConsumer(message_type=MockMessage)
consumer_three = MockMessageConsumer(message_type=MockMessage)
mq = SimpleMessageQueue()

await mq.register_consumer(consumer_one)
await mq.register_consumer(consumer_two)
await mq.register_consumer(consumer_three)

# Act
await mq.deregister_consumer(consumer_one)
await mq.deregister_consumer(consumer_three)
with pytest.raises(ValueError):
await mq.deregister_consumer(consumer_three)

# Assert
assert len(await mq.get_consumers(MockMessage)) == 1
assert len(await mq.get_consumers(BaseMessage)) == 0


@pytest.mark.asyncio()
async def test_simple_publish_consumer() -> None:
# Arrange
consumer_one = MockMessageConsumer()
consumer_two = MockMessageConsumer(message_type=MockMessage)
mq = SimpleMessageQueue()
task = asyncio.create_task(mq.start())

await mq.register_consumer(consumer_one)
await mq.register_consumer(consumer_two)

# Act
await mq.publish(BaseMessage(id_="1"))
await mq.publish(MockMessage(id_="2"))
await mq.publish(MockMessage(id_="3"))

# Give some time for last message to get published and sent to consumers
await asyncio.sleep(0.5)
task.cancel()

# Assert
assert ["1"] == [m.id_ for m in consumer_one.processed_messages]
assert ["2", "3"] == [m.id_ for m in consumer_two.processed_messages]

0 comments on commit a562d7c

Please sign in to comment.