-
Notifications
You must be signed in to change notification settings - Fork 184
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* scaffold * fmt * wip * add simple * test for simple mq * add different consumer type to test * deregister test * add publish test * infinite loop for passing messages * lock to privateattr * minor fixes --------- Co-authored-by: Logan Markewich <[email protected]>
- Loading branch information
1 parent
68e19ee
commit a562d7c
Showing
15 changed files
with
332 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
{ | ||
"[python]": { | ||
"editor.formatOnSave": true, | ||
"editor.codeActionsOnSave": { | ||
"source.fixAll": "explicit" | ||
}, | ||
"editor.defaultFormatter": "ms-python.black-formatter" | ||
}, | ||
"python.testing.pytestArgs": ["tests"], | ||
"python.testing.unittestEnabled": false, | ||
"python.testing.pytestEnabled": true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
"""Message consumers.""" | ||
|
||
import uuid | ||
from abc import ABC, abstractmethod | ||
from typing import Any, Type, TYPE_CHECKING | ||
from llama_index.core.bridge.pydantic import BaseModel, Field | ||
from agentfile.messages.base import BaseMessage | ||
|
||
if TYPE_CHECKING: | ||
from agentfile.message_queues.base import BaseMessageQueue | ||
|
||
|
||
class BaseMessageQueueConsumer(BaseModel, ABC): | ||
"""Consumer of a MessageQueue.""" | ||
|
||
id_: str = Field(default_factory=lambda: str(uuid.uuid4())) | ||
message_type: Type[BaseMessage] = Field(default=BaseMessage) | ||
|
||
class Config: | ||
arbitrary_types_allowed = True | ||
|
||
@abstractmethod | ||
async def _process_message(self, message: BaseMessage, **kwargs: Any) -> Any: | ||
"""Subclasses should implement logic here.""" | ||
|
||
async def process_message(self, message: BaseMessage, **kwargs: Any) -> Any: | ||
"""Logic for processing message.""" | ||
if not isinstance(message, self.message_type): | ||
raise ValueError("Consumer cannot process the given kind of Message.") | ||
return await self._process_message(message, **kwargs) | ||
|
||
async def start_consuming( | ||
self, message_queue: "BaseMessageQueue", **kwargs: Any | ||
) -> None: | ||
"""Begin consuming messages.""" | ||
await message_queue.register_consumer(self, **kwargs) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
"""Message queue module.""" | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Any, List, Protocol, Type, TYPE_CHECKING | ||
from llama_index.core.bridge.pydantic import BaseModel | ||
from agentfile.messages.base import BaseMessage | ||
|
||
if TYPE_CHECKING: | ||
from agentfile.message_consumers.base import BaseMessageQueueConsumer | ||
|
||
|
||
class MessageProcessor(Protocol): | ||
"""Protocol for a callable that processes messages.""" | ||
|
||
def __call__(self, message: BaseMessage, **kwargs: Any) -> None: | ||
... | ||
|
||
|
||
class BaseMessageQueue(BaseModel, ABC): | ||
"""Message broker interface between publisher and consumer.""" | ||
|
||
class Config: | ||
arbitrary_types_allowed = True | ||
|
||
@abstractmethod | ||
async def _publish(self, message: BaseMessage, **kwargs: Any) -> Any: | ||
"""Subclasses implement publish logic here.""" | ||
... | ||
|
||
async def publish(self, message: BaseMessage, **kwargs: Any) -> Any: | ||
"""Send message to a consumer.""" | ||
await self._publish(message, **kwargs) | ||
|
||
@abstractmethod | ||
async def register_consumer( | ||
self, consumer: "BaseMessageQueueConsumer", **kwargs: Any | ||
) -> Any: | ||
"""Register consumer to start consuming messages.""" | ||
|
||
@abstractmethod | ||
async def deregister_consumer(self, consumer: "BaseMessageQueueConsumer") -> Any: | ||
"""Deregister consumer to stop publishing messages).""" | ||
|
||
async def get_consumers( | ||
self, message_type: Type[BaseMessage] | ||
) -> List["BaseMessageQueueConsumer"]: | ||
"""Gets list of consumers according to a message type.""" | ||
raise NotImplementedError( | ||
"`get_consumers()` is not implemented for this class." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
"""Simple Message Queue.""" | ||
|
||
import asyncio | ||
import random | ||
|
||
from collections import deque | ||
from typing import Any, Dict, List, Type | ||
from llama_index.core.bridge.pydantic import Field | ||
from agentfile.message_queues.base import BaseMessageQueue | ||
from agentfile.messages.base import BaseMessage | ||
from agentfile.message_consumers.base import BaseMessageQueueConsumer | ||
|
||
|
||
class SimpleMessageQueue(BaseMessageQueue): | ||
"""SimpleMessageQueue. | ||
An in-memory message queue that implements a push model for consumers. | ||
""" | ||
|
||
consumers: Dict[str, Dict[str, BaseMessageQueueConsumer]] = Field( | ||
default_factory=dict | ||
) | ||
queues: Dict[str, deque] = Field(default_factory=dict) | ||
running: bool = True | ||
|
||
def __init__( | ||
self, | ||
consumers: Dict[str, Dict[str, BaseMessageQueueConsumer]] = {}, | ||
queues: Dict[str, deque] = {}, | ||
): | ||
super().__init__(consumers=consumers, queues=queues) | ||
|
||
def _select_consumer(self, message: BaseMessage) -> BaseMessageQueueConsumer: | ||
"""Select a single consumer to publish a message to.""" | ||
message_type_str = message.class_name() | ||
consumer_id = random.choice(list(self.consumers[message_type_str].keys())) | ||
return self.consumers[message_type_str][consumer_id] | ||
|
||
async def _publish(self, message: BaseMessage, **kwargs: Any) -> Any: | ||
"""Publish message to a queue.""" | ||
message_type_str = message.class_name() | ||
|
||
if message_type_str not in self.consumers: | ||
raise ValueError(f"No consumer for {message_type_str} has been registered.") | ||
|
||
if message_type_str not in self.queues: | ||
self.queues[message_type_str] = deque() | ||
|
||
self.queues[message_type_str].append(message) | ||
|
||
async def _publish_to_consumer(self, message: BaseMessage, **kwargs: Any) -> Any: | ||
"""Publish message to a consumer.""" | ||
consumer = self._select_consumer(message) | ||
try: | ||
await consumer.process_message(message, **kwargs) | ||
except Exception: | ||
raise | ||
|
||
async def start(self) -> None: | ||
"""A loop for getting messages from queues and sending to consumer.""" | ||
while self.running: | ||
print(self.queues) | ||
for queue in self.queues.values(): | ||
if queue: | ||
message = queue.popleft() | ||
await self._publish_to_consumer(message) | ||
print(self.queues) | ||
await asyncio.sleep(0.1) | ||
|
||
async def register_consumer( | ||
self, consumer: BaseMessageQueueConsumer, **kwargs: Any | ||
) -> None: | ||
"""Register a new consumer.""" | ||
message_type_str = consumer.message_type.class_name() | ||
|
||
if message_type_str not in self.consumers: | ||
self.consumers[message_type_str] = {consumer.id_: consumer} | ||
else: | ||
if consumer.id_ in self.consumers[message_type_str]: | ||
raise ValueError("Consumer has already been added.") | ||
|
||
self.consumers[message_type_str][consumer.id_] = consumer | ||
|
||
if message_type_str not in self.queues: | ||
self.queues[message_type_str] = deque() | ||
|
||
async def deregister_consumer(self, consumer: BaseMessageQueueConsumer) -> None: | ||
message_type_str = consumer.message_type.class_name() | ||
if consumer.id_ not in self.consumers[message_type_str]: | ||
raise ValueError("No consumer found for associated message type.") | ||
|
||
del self.consumers[message_type_str][consumer.id_] | ||
if len(self.consumers[message_type_str]) == 0: | ||
del self.consumers[message_type_str] | ||
|
||
async def get_consumers( | ||
self, message_type: Type[BaseMessage] | ||
) -> List[BaseMessageQueueConsumer]: | ||
message_type_str = message_type.class_name() | ||
if message_type_str not in self.consumers: | ||
return [] | ||
|
||
return list(self.consumers[message_type_str].values()) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""Base Message.""" | ||
|
||
import uuid | ||
from typing import Any, Optional | ||
from llama_index.core.bridge.pydantic import BaseModel, Field | ||
|
||
|
||
class BaseMessage(BaseModel): | ||
id_: str = Field(default_factory=lambda: str(uuid.uuid4)) | ||
data: Optional[Any] = Field(default_factory=None) | ||
|
||
@classmethod | ||
def class_name(cls) -> str: | ||
"""Class name.""" | ||
return "BaseMessage" | ||
|
||
class Config: | ||
arbitrary_types_allowed = True |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import asyncio | ||
import pytest | ||
from typing import Any, List | ||
from llama_index.core.bridge.pydantic import PrivateAttr | ||
from agentfile.message_consumers.base import BaseMessageQueueConsumer | ||
from agentfile.message_queues.simple import SimpleMessageQueue | ||
from agentfile.messages.base import BaseMessage | ||
|
||
|
||
class MockMessageConsumer(BaseMessageQueueConsumer): | ||
processed_messages: List[BaseMessage] = [] | ||
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) | ||
|
||
async def _process_message(self, message: BaseMessage, **kwargs: Any) -> None: | ||
async with self._lock: | ||
self.processed_messages.append(message) | ||
|
||
|
||
class MockMessage(BaseMessage): | ||
@classmethod | ||
def class_name(cls) -> str: | ||
return "MockMessage" | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test_simple_register_consumer() -> None: | ||
# Arrange | ||
consumer_one = MockMessageConsumer() | ||
consumer_two = MockMessageConsumer(message_type=MockMessage) | ||
mq = SimpleMessageQueue() | ||
|
||
# Act | ||
await mq.register_consumer(consumer_one) | ||
await mq.register_consumer(consumer_two) | ||
with pytest.raises(ValueError): | ||
await mq.register_consumer(consumer_two) | ||
|
||
# Assert | ||
assert consumer_one.id_ in [ | ||
c.id_ for c in await mq.get_consumers(consumer_one.message_type) | ||
] | ||
assert consumer_two.id_ in [ | ||
c.id_ for c in await mq.get_consumers(consumer_two.message_type) | ||
] | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test_simple_deregister_consumer() -> None: | ||
# Arrange | ||
consumer_one = MockMessageConsumer() | ||
consumer_two = MockMessageConsumer(message_type=MockMessage) | ||
consumer_three = MockMessageConsumer(message_type=MockMessage) | ||
mq = SimpleMessageQueue() | ||
|
||
await mq.register_consumer(consumer_one) | ||
await mq.register_consumer(consumer_two) | ||
await mq.register_consumer(consumer_three) | ||
|
||
# Act | ||
await mq.deregister_consumer(consumer_one) | ||
await mq.deregister_consumer(consumer_three) | ||
with pytest.raises(ValueError): | ||
await mq.deregister_consumer(consumer_three) | ||
|
||
# Assert | ||
assert len(await mq.get_consumers(MockMessage)) == 1 | ||
assert len(await mq.get_consumers(BaseMessage)) == 0 | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test_simple_publish_consumer() -> None: | ||
# Arrange | ||
consumer_one = MockMessageConsumer() | ||
consumer_two = MockMessageConsumer(message_type=MockMessage) | ||
mq = SimpleMessageQueue() | ||
task = asyncio.create_task(mq.start()) | ||
|
||
await mq.register_consumer(consumer_one) | ||
await mq.register_consumer(consumer_two) | ||
|
||
# Act | ||
await mq.publish(BaseMessage(id_="1")) | ||
await mq.publish(MockMessage(id_="2")) | ||
await mq.publish(MockMessage(id_="3")) | ||
|
||
# Give some time for last message to get published and sent to consumers | ||
await asyncio.sleep(0.5) | ||
task.cancel() | ||
|
||
# Assert | ||
assert ["1"] == [m.id_ for m in consumer_one.processed_messages] | ||
assert ["2", "3"] == [m.id_ for m in consumer_two.processed_messages] |