Skip to content

Commit

Permalink
Fix typing issues with mypy by using a generic type for AlephMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
MHHukiewitz committed Nov 24, 2023
1 parent fba6fc0 commit 2a0f34d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
22 changes: 9 additions & 13 deletions aleph_message/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Literal,
Optional,
Type,
Union,
Union, TypeVar, cast,
)

from pydantic import BaseModel, Extra, Field, validator
Expand Down Expand Up @@ -342,14 +342,10 @@ class InstanceMessage(BaseMessage):
ForgetMessage,
]

AlephMessageType: TypeAlias = Union[
Type[PostMessage],
Type[AggregateMessage],
Type[StoreMessage],
Type[ProgramMessage],
Type[InstanceMessage],
Type[ForgetMessage],
]

T = TypeVar('T', bound=AlephMessage)

AlephMessageType: TypeAlias = Type[T]

message_classes: List[AlephMessageType] = [
PostMessage,
Expand Down Expand Up @@ -391,16 +387,16 @@ def add_item_content_and_hash(message_dict: Dict, inplace: bool = False):

def create_new_message(
message_dict: Dict,
factory: Optional[AlephMessageType] = None,
) -> AlephMessage:
factory: Optional[Type[T]] = None,
) -> T:
"""Create a new message from a dict.
Computes the 'item_content' and 'item_hash' fields.
"""
message_content = add_item_content_and_hash(message_dict)
if factory:
return factory.parse_obj(message_content)
return cast(T, factory.parse_obj(message_content))
else:
return parse_message(message_content)
return cast(T, parse_message(message_content))


def create_message_from_json(
Expand Down
6 changes: 3 additions & 3 deletions aleph_message/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
create_message_from_file,
create_message_from_json,
create_new_message,
parse_message,
parse_message, AlephMessage,
)
from aleph_message.tests.download_messages import MESSAGES_STORAGE_PATH

Expand Down Expand Up @@ -87,7 +87,7 @@ def test_messages_last_page():
if message_dict["item_hash"] in HASHES_TO_IGNORE:
continue

message = parse_message(message_dict)
message: AlephMessage = parse_message(message_dict)
assert message


Expand Down Expand Up @@ -308,7 +308,7 @@ def test_messages_from_disk():
data_dict = json.load(page_fd)
for message_dict in data_dict["messages"]:
try:
message = parse_message(message_dict)
message: AlephMessage = parse_message(message_dict)
assert message
except ValidationError as e:
console.print("-" * 79)
Expand Down

0 comments on commit 2a0f34d

Please sign in to comment.