Skip to content

Commit

Permalink
Problem: the msgType parameter when querying messages only accepts a …
Browse files Browse the repository at this point in the history
…single message type, where other params accept a list of values

Solution: add msgTypes field and deprecate msgType
  • Loading branch information
MHHukiewitz committed May 28, 2023
1 parent 64c97e2 commit 494ad39
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
11 changes: 10 additions & 1 deletion src/aleph/db/accessors/messages.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime as dt
import traceback
import warnings
from typing import Optional, Sequence, Union, Iterable, Any, Mapping, overload, Tuple

from aleph_message.models import ItemHash, Chain, MessageType
Expand Down Expand Up @@ -52,6 +53,7 @@ def make_matching_messages_query(
refs: Optional[Sequence[str]] = None,
chains: Optional[Sequence[Chain]] = None,
message_type: Optional[MessageType] = None,
message_types: Optional[Sequence[MessageType]] = None,
start_date: Optional[Union[float, dt.datetime]] = None,
end_date: Optional[Union[float, dt.datetime]] = None,
content_hashes: Optional[Sequence[ItemHash]] = None,
Expand Down Expand Up @@ -87,7 +89,15 @@ def make_matching_messages_query(
select_stmt = select_stmt.where(MessageDb.sender.in_(addresses))
if chains:
select_stmt = select_stmt.where(MessageDb.chain.in_(chains))
if message_types:
if len(message_types) == 1:
select_stmt = select_stmt.where(MessageDb.type == message_types[0])
else:
select_stmt = select_stmt.where(MessageDb.type.in_(message_types))
if message_type:
warnings.warn(
"Warning: `msgType`/`message_type` query parameter is deprecated in favor of `msgTypes`\`message_types` and will be removed in future versions."
)
select_stmt = select_stmt.where(MessageDb.type == message_type)
if start_datetime:
select_stmt = select_stmt.where(MessageDb.time >= start_datetime)
Expand Down Expand Up @@ -234,7 +244,6 @@ def refresh_address_stats_mat_view(session: DbSession) -> None:
def get_unconfirmed_messages(
session: DbSession, limit: int = 100, chain: Optional[Chain] = None
) -> Iterable[MessageDb]:

if chain is None:
select_message_confirmations = select(message_confirmations.c.item_hash).where(
message_confirmations.c.item_hash == MessageDb.item_hash
Expand Down
9 changes: 6 additions & 3 deletions src/aleph/web/controllers/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
from aleph.web.controllers.app_state_getters import (
get_session_factory_from_request,
get_mq_channel_from_request,
get_config_from_request, get_mq_ws_channel_from_request,
get_config_from_request,
get_mq_ws_channel_from_request,
)
from aleph.web.controllers.utils import (
DEFAULT_MESSAGES_PER_PAGE,
Expand Down Expand Up @@ -64,7 +65,10 @@ class BaseMessageQueryParams(BaseModel):
"-1 means most recent messages first, 1 means older messages first.",
)
message_type: Optional[MessageType] = Field(
default=None, alias="msgType", description="Message type."
default=None, alias="msgType", description="[DEPRECATED] Message type."
)
message_types: Optional[List[MessageType]] = Field(
default=None, alias="msgTypes", description="Accepted types of messages."
)
addresses: Optional[List[str]] = Field(
default=None, description="Accepted values for the 'sender' field."
Expand Down Expand Up @@ -239,7 +243,6 @@ async def _send_history_to_ws(
history: int,
message_filters: Dict[str, Any],
) -> None:

with session_factory() as session:
messages = get_matching_messages(
session=session,
Expand Down

0 comments on commit 494ad39

Please sign in to comment.