diff --git a/src/aleph/db/accessors/messages.py b/src/aleph/db/accessors/messages.py index 08be2a457..cbccee4af 100644 --- a/src/aleph/db/accessors/messages.py +++ b/src/aleph/db/accessors/messages.py @@ -52,6 +52,7 @@ def make_matching_messages_query( refs: Optional[Sequence[str]] = None, chains: Optional[Sequence[Chain]] = None, message_type: Optional[MessageType] = None, + message_types: Optional[Sequence[MessageType]] = None, start_date: Optional[Union[float, dt.datetime]] = None, end_date: Optional[Union[float, dt.datetime]] = None, content_hashes: Optional[Sequence[ItemHash]] = None, @@ -87,6 +88,8 @@ def make_matching_messages_query( select_stmt = select_stmt.where(MessageDb.sender.in_(addresses)) if chains: select_stmt = select_stmt.where(MessageDb.chain.in_(chains)) + if message_types: + select_stmt = select_stmt.where(MessageDb.type.in_(message_types)) if message_type: select_stmt = select_stmt.where(MessageDb.type == message_type) if start_datetime: diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index 6d1f409b1..47d49655a 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -67,7 +67,10 @@ class BaseMessageQueryParams(BaseModel): "-1 means most recent messages first, 1 means older messages first.", ) message_type: Optional[MessageType] = Field( - default=None, alias="msgType", description="Message type." + default=None, alias="msgType", description="Message type. Deprecated: use msgTypes instead" + ) + message_types: Optional[List[MessageType]] = Field( + default=None, alias="msgTypes", description="Accepted message types." ) addresses: Optional[List[str]] = Field( default=None, description="Accepted values for the 'sender' field." @@ -120,6 +123,7 @@ def validate_field_dependencies(cls, values): "content_types", "chains", "channels", + "message_types", "tags", pre=True, ) @@ -356,7 +360,6 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse: except ValidationError as e: raise web.HTTPUnprocessableEntity(body=e.json(indent=4)) - message_filters = query_params.dict(exclude_none=True) history = query_params.history if history: diff --git a/tests/api/test_list_messages.py b/tests/api/test_list_messages.py index 5d97d4e97..69b4534f1 100644 --- a/tests/api/test_list_messages.py +++ b/tests/api/test_list_messages.py @@ -1,5 +1,6 @@ import datetime as dt import itertools +from collections import defaultdict from typing import Any, Dict, Iterable, List, Optional, Sequence, Union, Tuple import aiohttp @@ -221,6 +222,24 @@ async def test_get_messages_filter_by_tags( assert messages[0]["item_hash"] == amend_message_db.item_hash +@pytest.mark.parametrize("type_field", ("msgType", "msgTypes")) +@pytest.mark.asyncio +async def test_get_by_message_type(fixture_messages, ccn_api_client, type_field: str): + messages_by_type = defaultdict(list) + for message in fixture_messages: + messages_by_type[message["type"]].append(message) + + for message_type, expected_messages in messages_by_type.items(): + response = await ccn_api_client.get( + MESSAGES_URI, params={type_field: message_type} + ) + assert response.status == 200, await response.text() + messages = (await response.json())["messages"] + assert set(msg["item_hash"] for msg in messages) == set( + msg["item_hash"] for msg in expected_messages + ) + + @pytest.mark.asyncio async def test_get_messages_filter_by_tags_no_match(fixture_messages, ccn_api_client): """