Skip to content

Commit

Permalink
Unit tests for remote client mq (#32)
Browse files Browse the repository at this point in the history
* wip

* add register_consumer unit test

* add unit test for deregister consumer

* dry

* change get_consumers in simple to path param

* unit test for publish; move up publish_time stats for QueueMessage

* update test

* mv and rename remote client to message_queues.simple

* cr
  • Loading branch information
nerdai authored Jun 23, 2024
1 parent e012077 commit 94ac16e
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 87 deletions.
2 changes: 1 addition & 1 deletion llama_agents/message_queues/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ async def publish(
) -> Any:
"""Send message to a consumer."""
logger.info("Publishing message: " + str(message))
await self._publish(message)
message.stats.publish_time = message.stats.timestamp_str()
await self._publish(message)

if callback:
if inspect.iscoroutinefunction(callback):
Expand Down
76 changes: 0 additions & 76 deletions llama_agents/message_queues/remote_client.py

This file was deleted.

94 changes: 89 additions & 5 deletions llama_agents/message_queues/simple.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Simple Message Queue."""

import asyncio
import httpx
import random
import logging
import uvicorn
Expand All @@ -9,7 +10,8 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from pydantic import Field, PrivateAttr
from typing import Any, AsyncGenerator, Dict, List
from typing import Any, AsyncGenerator, Dict, List, Optional
from urllib.parse import urljoin

from llama_agents.message_queues.base import BaseMessageQueue
from llama_agents.messages.base import QueueMessage
Expand All @@ -18,12 +20,96 @@
RemoteMessageConsumer,
RemoteMessageConsumerDef,
)
from llama_agents.types import PydanticValidatedUrl

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)


class SimpleRemoteClientMessageQueue(BaseMessageQueue):
"""Remote client to be used with a SimpleMessageQueue server."""

base_url: PydanticValidatedUrl
client_kwargs: Optional[Dict] = None
client: Optional[httpx.AsyncClient] = None

async def _publish(
self, message: QueueMessage, publish_url: str = "publish", **kwargs: Any
) -> Any:
client_kwargs = self.client_kwargs or {}
client = self.client or httpx.AsyncClient(**client_kwargs)
url = urljoin(self.base_url, publish_url)
async with httpx.AsyncClient() as client:
result = await client.post(url, json=message.model_dump())
return result

async def register_consumer(
self,
consumer: BaseMessageQueueConsumer,
register_consumer_url: str = "register_consumer",
**kwargs: Any,
) -> httpx.Response:
client_kwargs = self.client_kwargs or {}
client = self.client or httpx.AsyncClient(**client_kwargs)
url = urljoin(self.base_url, register_consumer_url)
try:
remote_consumer_def = RemoteMessageConsumerDef(**consumer.model_dump())
except Exception as e:
raise ValueError(
"Unable to convert consumer to RemoteMessageConsumer"
) from e
async with httpx.AsyncClient() as client:
result = await client.post(url, json=remote_consumer_def.model_dump())
return result

async def deregister_consumer(
self,
consumer: BaseMessageQueueConsumer,
deregister_consumer_url: str = "deregister_consumer",
) -> Any:
client_kwargs = self.client_kwargs or {}
client = self.client or httpx.AsyncClient(**client_kwargs)
url = urljoin(self.base_url, deregister_consumer_url)
try:
remote_consumer_def = RemoteMessageConsumerDef(**consumer.model_dump())
except Exception as e:
raise ValueError(
"Unable to convert consumer to RemoteMessageConsumer"
) from e
async with httpx.AsyncClient() as client:
result = await client.post(url, json=remote_consumer_def.model_dump())
return result

async def get_consumers(
self, message_type: str, get_consumers_url: str = "get_consumers"
) -> List[BaseMessageQueueConsumer]:
client_kwargs = self.client_kwargs or {}
client = self.client or httpx.AsyncClient(**client_kwargs)
url = urljoin(self.base_url, f"{get_consumers_url}/{message_type}")
async with httpx.AsyncClient() as client:
res = await client.get(url)
if res.status_code == 200:
remote_consumer_defs = res.json()
consumers = [RemoteMessageConsumer(**el) for el in remote_consumer_defs]
else:
consumers = []
return consumers

async def processing_loop(self) -> None:
raise NotImplementedError(
"`procesing_loop()` is not implemented for this class."
)

async def launch_local(self) -> None:
raise NotImplementedError("`launch_local()` is not implemented for this class.")

async def launch_server(self) -> None:
raise NotImplementedError(
"`launch_server()` is not implemented for this class."
)


class SimpleMessageQueue(BaseMessageQueue):
"""SimpleMessageQueue.
Expand Down Expand Up @@ -66,7 +152,7 @@ def __init__(
)

self._app.add_api_route(
"/get_consumers",
"/get_consumers/{message_type}",
self.get_consumer_defs,
methods=["GET"],
tags=["Consumers"],
Expand All @@ -81,10 +167,8 @@ def __init__(

@property
def client(self) -> BaseMessageQueue:
from llama_agents.message_queues.remote_client import RemoteClientMessageQueue

base_url = f"http://{self.host}:{self.port}"
return RemoteClientMessageQueue(base_url=base_url)
return SimpleRemoteClientMessageQueue(base_url=base_url)

def _select_consumer(self, message: QueueMessage) -> BaseMessageQueueConsumer:
"""Select a single consumer to publish a message to."""
Expand Down
9 changes: 8 additions & 1 deletion llama_agents/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import uuid
from enum import Enum
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, BeforeValidator, HttpUrl, TypeAdapter
from pydantic.v1 import BaseModel as V1BaseModel
from typing import Any, Dict, List, Optional, Union
from typing_extensions import Annotated

from llama_index.core.llms import MessageRole

Expand Down Expand Up @@ -118,3 +119,9 @@ class ServiceDefinition(BaseModel):

class HumanResponse(BaseModel):
result: str


http_url_adapter = TypeAdapter(HttpUrl)
PydanticValidatedUrl = Annotated[
str, BeforeValidator(lambda value: str(http_url_adapter.validate_python(value)))
]
4 changes: 1 addition & 3 deletions tests/message_queues/test_simple_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_get_consumers() -> None:

# act
_ = test_client.post("/register_consumer", json=remote_consumer_def.model_dump())
response = test_client.get("/get_consumers/?message_type=mock_type")
response = test_client.get("/get_consumers/mock_type")

# assert
assert response.status_code == 200
Expand All @@ -89,10 +89,8 @@ async def test_publish() -> None:
message = QueueMessage(
type="mock_type", data={"payload": "mock payload"}, action=ActionTypes.NEW_TASK
)
print(message.model_dump())
response = test_client.post("/publish", json=message.model_dump())

# assert
print(response.content)
assert response.status_code == 200
assert mq.queues["mock_type"][0] == message
Loading

0 comments on commit 94ac16e

Please sign in to comment.