From ec343268cb9583359dacf1fbe46aefe2950a6efa Mon Sep 17 00:00:00 2001 From: Mike Hukiewitz Date: Sun, 28 May 2023 15:06:25 +0200 Subject: [PATCH 1/4] Problem: the msgType parameter when querying messages only accepts a single message type, where other params accept a list of values Solution: add msgTypes field and deprecate msgType --- src/aleph/db/accessors/messages.py | 11 ++++++++++- src/aleph/web/controllers/messages.py | 5 ++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/aleph/db/accessors/messages.py b/src/aleph/db/accessors/messages.py index 08be2a457..887f81bfd 100644 --- a/src/aleph/db/accessors/messages.py +++ b/src/aleph/db/accessors/messages.py @@ -1,5 +1,6 @@ import datetime as dt import traceback +import warnings from typing import Optional, Sequence, Union, Iterable, Any, Mapping, overload, Tuple from aleph_message.models import ItemHash, Chain, MessageType @@ -52,6 +53,7 @@ def make_matching_messages_query( refs: Optional[Sequence[str]] = None, chains: Optional[Sequence[Chain]] = None, message_type: Optional[MessageType] = None, + message_types: Optional[Sequence[MessageType]] = None, start_date: Optional[Union[float, dt.datetime]] = None, end_date: Optional[Union[float, dt.datetime]] = None, content_hashes: Optional[Sequence[ItemHash]] = None, @@ -87,7 +89,15 @@ def make_matching_messages_query( select_stmt = select_stmt.where(MessageDb.sender.in_(addresses)) if chains: select_stmt = select_stmt.where(MessageDb.chain.in_(chains)) + if message_types: + if len(message_types) == 1: + select_stmt = select_stmt.where(MessageDb.type == message_types[0]) + else: + select_stmt = select_stmt.where(MessageDb.type.in_(message_types)) if message_type: + warnings.warn( + "Warning: `msgType`/`message_type` query parameter is deprecated in favor of `msgTypes`\`message_types` and will be removed in future versions." + ) select_stmt = select_stmt.where(MessageDb.type == message_type) if start_datetime: select_stmt = select_stmt.where(MessageDb.time >= start_datetime) @@ -234,7 +244,6 @@ def refresh_address_stats_mat_view(session: DbSession) -> None: def get_unconfirmed_messages( session: DbSession, limit: int = 100, chain: Optional[Chain] = None ) -> Iterable[MessageDb]: - if chain is None: select_message_confirmations = select(message_confirmations.c.item_hash).where( message_confirmations.c.item_hash == MessageDb.item_hash diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index 6d1f409b1..c18ab068d 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -67,7 +67,10 @@ class BaseMessageQueryParams(BaseModel): "-1 means most recent messages first, 1 means older messages first.", ) message_type: Optional[MessageType] = Field( - default=None, alias="msgType", description="Message type." + default=None, alias="msgType", description="[DEPRECATED] Message type." + ) + message_types: Optional[List[MessageType]] = Field( + default=None, alias="msgTypes", description="Accepted types of messages." ) addresses: Optional[List[str]] = Field( default=None, description="Accepted values for the 'sender' field." From 39b600d05a0f2d0aefc33ddbced664820044226b Mon Sep 17 00:00:00 2001 From: Mike Hukiewitz Date: Tue, 6 Jun 2023 15:56:14 +0200 Subject: [PATCH 2/4] commit suggestions --- src/aleph/db/accessors/messages.py | 10 ++-------- src/aleph/web/controllers/messages.py | 4 ++-- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/aleph/db/accessors/messages.py b/src/aleph/db/accessors/messages.py index 887f81bfd..cbccee4af 100644 --- a/src/aleph/db/accessors/messages.py +++ b/src/aleph/db/accessors/messages.py @@ -1,6 +1,5 @@ import datetime as dt import traceback -import warnings from typing import Optional, Sequence, Union, Iterable, Any, Mapping, overload, Tuple from aleph_message.models import ItemHash, Chain, MessageType @@ -90,14 +89,8 @@ def make_matching_messages_query( if chains: select_stmt = select_stmt.where(MessageDb.chain.in_(chains)) if message_types: - if len(message_types) == 1: - select_stmt = select_stmt.where(MessageDb.type == message_types[0]) - else: - select_stmt = select_stmt.where(MessageDb.type.in_(message_types)) + select_stmt = select_stmt.where(MessageDb.type.in_(message_types)) if message_type: - warnings.warn( - "Warning: `msgType`/`message_type` query parameter is deprecated in favor of `msgTypes`\`message_types` and will be removed in future versions." - ) select_stmt = select_stmt.where(MessageDb.type == message_type) if start_datetime: select_stmt = select_stmt.where(MessageDb.time >= start_datetime) @@ -244,6 +237,7 @@ def refresh_address_stats_mat_view(session: DbSession) -> None: def get_unconfirmed_messages( session: DbSession, limit: int = 100, chain: Optional[Chain] = None ) -> Iterable[MessageDb]: + if chain is None: select_message_confirmations = select(message_confirmations.c.item_hash).where( message_confirmations.c.item_hash == MessageDb.item_hash diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index c18ab068d..76f2a58d8 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -67,10 +67,10 @@ class BaseMessageQueryParams(BaseModel): "-1 means most recent messages first, 1 means older messages first.", ) message_type: Optional[MessageType] = Field( - default=None, alias="msgType", description="[DEPRECATED] Message type." + default=None, alias="msgType", description="Message type. Deprecated: use msgTypes instead" ) message_types: Optional[List[MessageType]] = Field( - default=None, alias="msgTypes", description="Accepted types of messages." + default=None, alias="msgTypes", description="Accepted message types." ) addresses: Optional[List[str]] = Field( default=None, description="Accepted values for the 'sender' field." From 48e69507e884c4eaaecebe0618f7d586278df9d7 Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Tue, 5 Sep 2023 14:38:49 +0200 Subject: [PATCH 3/4] Remove useless variable --- src/aleph/web/controllers/messages.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index 76f2a58d8..2bb5545a3 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -359,7 +359,6 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse: except ValidationError as e: raise web.HTTPUnprocessableEntity(body=e.json(indent=4)) - message_filters = query_params.dict(exclude_none=True) history = query_params.history if history: From b16da02305c1dca13662cf1296bb58a8398eec78 Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Tue, 5 Sep 2023 14:47:00 +0200 Subject: [PATCH 4/4] fix list issue --- src/aleph/web/controllers/messages.py | 1 + tests/api/test_list_messages.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index 2bb5545a3..47d49655a 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -123,6 +123,7 @@ def validate_field_dependencies(cls, values): "content_types", "chains", "channels", + "message_types", "tags", pre=True, ) diff --git a/tests/api/test_list_messages.py b/tests/api/test_list_messages.py index 5d97d4e97..69b4534f1 100644 --- a/tests/api/test_list_messages.py +++ b/tests/api/test_list_messages.py @@ -1,5 +1,6 @@ import datetime as dt import itertools +from collections import defaultdict from typing import Any, Dict, Iterable, List, Optional, Sequence, Union, Tuple import aiohttp @@ -221,6 +222,24 @@ async def test_get_messages_filter_by_tags( assert messages[0]["item_hash"] == amend_message_db.item_hash +@pytest.mark.parametrize("type_field", ("msgType", "msgTypes")) +@pytest.mark.asyncio +async def test_get_by_message_type(fixture_messages, ccn_api_client, type_field: str): + messages_by_type = defaultdict(list) + for message in fixture_messages: + messages_by_type[message["type"]].append(message) + + for message_type, expected_messages in messages_by_type.items(): + response = await ccn_api_client.get( + MESSAGES_URI, params={type_field: message_type} + ) + assert response.status == 200, await response.text() + messages = (await response.json())["messages"] + assert set(msg["item_hash"] for msg in messages) == set( + msg["item_hash"] for msg in expected_messages + ) + + @pytest.mark.asyncio async def test_get_messages_filter_by_tags_no_match(fixture_messages, ccn_api_client): """