From 09ff7b7e7477440dd073b7f46c22e322f2e61ac8 Mon Sep 17 00:00:00 2001 From: Alexis Rossfelder Date: Mon, 29 Apr 2024 17:49:44 +0200 Subject: [PATCH] Change the way Serializable classes work Provide a way to test serialization, deserialization, validation and deserialization errors easily. This fixes Avoid repetition in tests for serialize+deserialize tests #64 and makes it easier to add new data types --- changes/273.internal.md | 90 +++++ mcproto/packets/handshaking/handshake.py | 64 +-- mcproto/packets/login/login.py | 220 +++++----- mcproto/packets/status/ping.py | 24 +- mcproto/packets/status/status.py | 35 +- mcproto/protocol/base_io.py | 12 +- mcproto/types/abc.py | 4 +- mcproto/types/chat.py | 34 +- mcproto/types/uuid.py | 9 +- mcproto/utils/abc.py | 71 +++- tests/helpers.py | 180 ++++++++- .../packets/handshaking/test_handshake.py | 103 +---- tests/mcproto/packets/login/test_login.py | 382 +++++------------- tests/mcproto/packets/status/test_ping.py | 46 +-- tests/mcproto/packets/status/test_status.py | 79 ++-- tests/mcproto/types/test_chat.py | 75 ++-- tests/mcproto/types/test_uuid.py | 44 +- tests/mcproto/utils/test_serializable.py | 56 +++ 18 files changed, 799 insertions(+), 729 deletions(-) create mode 100644 changes/273.internal.md create mode 100644 tests/mcproto/utils/test_serializable.py diff --git a/changes/273.internal.md b/changes/273.internal.md new file mode 100644 index 00000000..4096e20f --- /dev/null +++ b/changes/273.internal.md @@ -0,0 +1,90 @@ +- Changed the way `Serializable` classes are handled: + + Here is how a basic `Serializable` class looks like: + + @final + @dataclass + class ToyClass(Serializable): + """Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`.""" + + + # Attributes can be of any type + a: int + b: str + + # dataclasses.field() can be used to specify additional metadata + + def serialize_to(self, buf: Buffer): + """Write the object to a buffer.""" + buf.write_varint(self.a) + buf.write_utf(self.b) + + @classmethod + def deserialize(cls, buf: Buffer) -> ToyClass: + """Deserialize the object from a buffer.""" + a = buf.read_varint() + if a == 0: + raise ZeroDivisionError("a must be non-zero") + b = buf.read_utf() + return cls(a, b) + + def validate(self) -> None: + """Validate the object's attributes.""" + if self.a == 0: + raise ZeroDivisionError("a must be non-zero") + if len(self.b) > 10: + raise ValueError("b must be less than 10 characters") + + + The `Serializable` class must implement the following methods: + + - `serialize_to(buf: Buffer) -> None`: Serializes the object to a buffer. + - `deserialize(buf: Buffer) -> Serializable`: Deserializes the object from a buffer. + - `validate() -> None`: Validates the object's attributes, raising an exception if they are invalid. + +- Added a test generator for `Serializable` classes: + + The `gen_serializable_test` function generates tests for `Serializable` classes. It takes the following arguments: + + - `context`: The dictionary containing the context in which the generated test class will be placed (e.g. `globals()`). + > Dictionary updates must reflect in the context. This is the case for `globals()` but implementation-specific for `locals()`. + - `cls`: The `Serializable` class to generate tests for. + - `fields`: A list of fields where the test values will be placed. + + > In the example above, the `ToyClass` class has two fields: `a` and `b`. + + - `test_data`: A list of tuples containing either: + - `((field1_value, field2_value, ...), expected_bytes)`: The values of the fields and the expected serialized bytes. This needs to work both ways, i.e. `cls(field1_value, field2_value, ...) == cls.deserialize(expected_bytes).` + - `((field1_value, field2_value, ...), exception)`: The values of the fields and the expected exception when validating the object. + - `(exception, bytes)`: The expected exception when deserializing the bytes and the bytes to deserialize. + + The `gen_serializable_test` function generates a test class with the following tests: + + gen_serializable_test( + context=globals(), + cls=ToyClass, + fields=[("a", int), ("b", str)], + test_data=[ + ((1, "hello"), b"\x01\x05hello"), + ((2, "world"), b"\x02\x05world"), + ((0, "hello"), ZeroDivisionError), + ((1, "hello world"), ValueError), + (ZeroDivisionError, b"\x00"), + (IOError, b"\x01"), + ], + ) + + The generated test class will have the following tests: + + class TestGenToyClass: + def test_serialization(self): + # 2 subtests for the cases 1 and 2 + + def test_deserialization(self): + # 2 subtests for the cases 1 and 2 + + def test_validation(self): + # 2 subtests for the cases 3 and 4 + + def test_exceptions(self): + # 2 subtests for the cases 5 and 6 diff --git a/mcproto/packets/handshaking/handshake.py b/mcproto/packets/handshaking/handshake.py index 46dbc219..4e1da977 100644 --- a/mcproto/packets/handshaking/handshake.py +++ b/mcproto/packets/handshaking/handshake.py @@ -1,13 +1,14 @@ from __future__ import annotations from enum import IntEnum -from typing import ClassVar, final +from typing import ClassVar, cast, final from typing_extensions import Self, override from mcproto.buffer import Buffer from mcproto.packets.packet import GameState, ServerBoundPacket from mcproto.protocol.base_io import StructFormat +from mcproto.utils.abc import dataclass __all__ = [ "NextState", @@ -23,49 +24,38 @@ class NextState(IntEnum): @final +@dataclass class Handshake(ServerBoundPacket): - """Initializes connection between server and client. (Client -> Server).""" + """Initializes connection between server and client. (Client -> Server). + + Initialize the Handshake packet. + + :param protocol_version: Protocol version number to be used. + :param server_address: The host/address the client is connecting to. + :param server_port: The port the client is connecting to. + :param next_state: The next state for the server to move into. + """ PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.HANDSHAKING - __slots__ = ("protocol_version", "server_address", "server_port", "next_state") - - def __init__( - self, - *, - protocol_version: int, - server_address: str, - server_port: int, - next_state: NextState | int, - ): - """Initialize the Handshake packet. - - :param protocol_version: Protocol version number to be used. - :param server_address: The host/address the client is connecting to. - :param server_port: The port the client is connecting to. - :param next_state: The next state for the server to move into. - """ - if not isinstance(next_state, NextState): # next_state is int - rev_lookup = {x.value: x for x in NextState.__members__.values()} - try: - next_state = rev_lookup[next_state] - except KeyError as exc: - raise ValueError("No such next_state.") from exc + # Slots are already managed by the dataclass decorator automatically. + # __slots__ = ("protocol_version", "server_address", "server_port", "next_state") - self.protocol_version = protocol_version - self.server_address = server_address - self.server_port = server_port - self.next_state = next_state + # _ : dataclasses.KW_ONLY # Only available in Python 3.10+ + protocol_version: int + server_address: str + server_port: int + next_state: NextState | int @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: + """Serialize the packet.""" + self.next_state = cast(NextState, self.next_state) # Handled by the validate method buf.write_varint(self.protocol_version) buf.write_utf(self.server_address) buf.write_value(StructFormat.USHORT, self.server_port) buf.write_varint(self.next_state.value) - return buf @override @classmethod @@ -76,3 +66,13 @@ def _deserialize(cls, buf: Buffer, /) -> Self: server_port=buf.read_value(StructFormat.USHORT), next_state=buf.read_varint(), ) + + @override + def validate(self) -> None: + """Validate the packet.""" + if not isinstance(self.next_state, NextState): + rev_lookup = {x.value: x for x in NextState.__members__.values()} + try: + self.next_state = rev_lookup[self.next_state] + except KeyError as exc: + raise ValueError("No such next_state.") from exc diff --git a/mcproto/packets/login/login.py b/mcproto/packets/login/login.py index 782aa505..16f686d8 100644 --- a/mcproto/packets/login/login.py +++ b/mcproto/packets/login/login.py @@ -11,6 +11,7 @@ from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket from mcproto.types.chat import ChatMessage from mcproto.types.uuid import UUID +from mcproto.utils.abc import dataclass __all__ = [ "LoginStart", @@ -25,29 +26,26 @@ @final +@dataclass class LoginStart(ServerBoundPacket): - """Packet from client asking to start login process. (Client -> Server).""" + """Packet from client asking to start login process. (Client -> Server). - __slots__ = ("username", "uuid") + Initialize the LoginStart packet. + + :param username: Username of the client who sent the request. + :param uuid: UUID of the player logging in (if the player doesn't have a UUID, this can be ``None``) + """ PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, *, username: str, uuid: UUID): - """Initialize the LoginStart packet. - - :param username: Username of the client who sent the request. - :param uuid: UUID of the player logging in (if the player doesn't have a UUID, this can be ``None``) - """ - self.username = username - self.uuid = uuid + username: str + uuid: UUID @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: buf.write_utf(self.username) - buf.extend(self.uuid.serialize()) - return buf + self.uuid.serialize_to(buf) @override @classmethod @@ -58,44 +56,39 @@ def _deserialize(cls, buf: Buffer, /) -> Self: @final +@dataclass class LoginEncryptionRequest(ClientBoundPacket): - """Used by the server to ask the client to encrypt the login process. (Server -> Client).""" + """Used by the server to ask the client to encrypt the login process. (Server -> Client). + + Initialize the LoginEncryptionRequest packet. - __slots__ = ("server_id", "public_key", "verify_token") + :param public_key: Server's public key. + :param verify_token: Sequence of random bytes generated by server for verification. + :param server_id: Empty on minecraft versions 1.7.X and higher (20 random chars pre 1.7). + """ PACKET_ID: ClassVar[int] = 0x01 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, *, server_id: str | None = None, public_key: RSAPublicKey, verify_token: bytes): - """Initialize the LoginEncryptionRequest packet. - - :param server_id: Empty on minecraft versions 1.7.X and higher (20 random chars pre 1.7). - :param public_key: Server's public key. - :param verify_token: Sequence of random bytes generated by server for verification. - """ - if server_id is None: - server_id = " " * 20 - - self.server_id = server_id - self.public_key = public_key - self.verify_token = verify_token + public_key: RSAPublicKey + verify_token: bytes + server_id: str | None = None @override - def serialize(self) -> Buffer: - public_key_raw = self.public_key.public_bytes(encoding=Encoding.DER, format=PublicFormat.SubjectPublicKeyInfo) + def serialize_to(self, buf: Buffer) -> None: + self.server_id = cast(str, self.server_id) - buf = Buffer() + public_key_raw = self.public_key.public_bytes(encoding=Encoding.DER, format=PublicFormat.SubjectPublicKeyInfo) buf.write_utf(self.server_id) buf.write_bytearray(public_key_raw) buf.write_bytearray(self.verify_token) - return buf @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: server_id = buf.read_utf() - public_key_raw = buf.read_bytearray() - verify_token = buf.read_bytearray() + public_key_raw = bytes(buf.read_bytearray()) + verify_token = bytes(buf.read_bytearray()) # Key type is determined by the passed key itself, we know in our case, it will # be an RSA public key, so we explicitly type-cast here. @@ -103,64 +96,65 @@ def _deserialize(cls, buf: Buffer, /) -> Self: return cls(server_id=server_id, public_key=public_key, verify_token=verify_token) + @override + def validate(self) -> None: + """Validate the packet.""" + if self.server_id is None: + self.server_id = " " * 20 + @final +@dataclass class LoginEncryptionResponse(ServerBoundPacket): - """Response from the client to :class:`LoginEncryptionRequest` packet. (Client -> Server).""" + """Response from the client to :class:`LoginEncryptionRequest` packet. (Client -> Server). + + Initialize the LoginEncryptionResponse packet. - __slots__ = ("shared_secret", "verify_token") + :param shared_secret: Shared secret value, encrypted with server's public key. + :param verify_token: Verify token value, encrypted with same public key. + """ PACKET_ID: ClassVar[int] = 0x01 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, *, shared_secret: bytes, verify_token: bytes): - """Initialize the LoginEncryptionResponse packet. - - :param shared_secret: Shared secret value, encrypted with server's public key. - :param verify_token: Verify token value, encrypted with same public key. - """ - self.shared_secret = shared_secret - self.verify_token = verify_token + shared_secret: bytes + verify_token: bytes @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: + """Serialize the packet.""" buf.write_bytearray(self.shared_secret) buf.write_bytearray(self.verify_token) - return buf @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - shared_secret = buf.read_bytearray() - verify_token = buf.read_bytearray() + shared_secret = bytes(buf.read_bytearray()) + verify_token = bytes(buf.read_bytearray()) return cls(shared_secret=shared_secret, verify_token=verify_token) @final +@dataclass class LoginSuccess(ClientBoundPacket): - """Sent by the server to denote a successful login. (Server -> Client).""" + """Sent by the server to denote a successful login. (Server -> Client). - __slots__ = ("uuid", "username") + Initialize the LoginSuccess packet. + + :param uuid: The UUID of the connecting player/client. + :param username: The username of the connecting player/client. + """ PACKET_ID: ClassVar[int] = 0x02 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, uuid: UUID, username: str): - """Initialize the LoginSuccess packet. - - :param uuid: The UUID of the connecting player/client. - :param username: The username of the connecting player/client. - """ - self.uuid = uuid - self.username = username + uuid: UUID + username: str @override - def serialize(self) -> Buffer: - buf = Buffer() - buf.extend(self.uuid.serialize()) + def serialize_to(self, buf: Buffer) -> None: + self.uuid.serialize_to(buf) buf.write_utf(self.username) - return buf @override @classmethod @@ -171,24 +165,23 @@ def _deserialize(cls, buf: Buffer, /) -> Self: @final +@dataclass class LoginDisconnect(ClientBoundPacket): - """Sent by the server to kick a player while in the login state. (Server -> Client).""" + """Sent by the server to kick a player while in the login state. (Server -> Client). + + Initialize the LoginDisconnect packet. - __slots__ = ("reason",) + :param reason: The reason for disconnection (kick). + """ PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, reason: ChatMessage): - """Initialize the LoginDisconnect packet. - - :param reason: The reason for disconnection (kick). - """ - self.reason = reason + reason: ChatMessage @override - def serialize(self) -> Buffer: - return self.reason.serialize() + def serialize_to(self, buf: Buffer) -> None: + self.reason.serialize_to(buf) @override @classmethod @@ -198,101 +191,92 @@ def _deserialize(cls, buf: Buffer, /) -> Self: @final +@dataclass class LoginPluginRequest(ClientBoundPacket): - """Sent by the server to implement a custom handshaking flow. (Server -> Client).""" + """Sent by the server to implement a custom handshaking flow. (Server -> Client). - __slots__ = ("message_id", "channel", "data") + Initialize the LoginPluginRequest. + + :param message_id: Message id, generated by the server, should be unique to the connection. + :param channel: Channel identifier, name of the plugin channel used to send data. + :param data: Data that is to be sent. + """ PACKET_ID: ClassVar[int] = 0x04 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, message_id: int, channel: str, data: bytes): - """Initialize the LoginPluginRequest. - - :param message_id: Message id, generated by the server, should be unique to the connection. - :param channel: Channel identifier, name of the plugin channel used to send data. - :param data: Data that is to be sent. - """ - self.message_id = message_id - self.channel = channel - self.data = data + message_id: int + channel: str + data: bytes @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: buf.write_varint(self.message_id) buf.write_utf(self.channel) buf.write(self.data) - return buf @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: message_id = buf.read_varint() channel = buf.read_utf() - data = buf.read(buf.remaining) # All of the remaining data in the buffer + data = bytes(buf.read(buf.remaining)) # All of the remaining data in the buffer return cls(message_id, channel, data) @final +@dataclass class LoginPluginResponse(ServerBoundPacket): - """Response to LoginPluginRequest from client. (Client -> Server).""" + """Response to LoginPluginRequest from client. (Client -> Server). - __slots__ = ("message_id", "data") + Initialize the LoginPluginRequest packet. + + :param message_id: Message id, generated by the server, should be unique to the connection. + :param data: Optional response data, present if client understood request. + """ PACKET_ID: ClassVar[int] = 0x02 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, message_id: int, data: bytes | None): - """Initialize the LoginPluginRequest packet. - - :param message_id: Message id, generated by the server, should be unique to the connection. - :param data: Optional response data, present if client understood request. - """ - self.message_id = message_id - self.data = data + message_id: int + data: bytes | None @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: buf.write_varint(self.message_id) buf.write_optional(self.data, buf.write) - return buf @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: message_id = buf.read_varint() - data = buf.read_optional(lambda: buf.read(buf.remaining)) + data = buf.read_optional(lambda: bytes(buf.read(buf.remaining))) return cls(message_id, data) @final +@dataclass class LoginSetCompression(ClientBoundPacket): """Sent by the server to specify whether to use compression on future packets or not (Server -> Client). - Note that this packet is optional, and if not set, the compression will not be enabled at all. - """ + Initialize the LoginSetCompression packet. + - __slots__ = ("threshold",) + :param threshold: + Maximum size of a packet before it is compressed. All packets smaller than this will remain uncompressed. + To disable compression completely, threshold can be set to -1. + + :note: This packet is optional, and if not set, the compression will not be enabled at all. + """ PACKET_ID: ClassVar[int] = 0x03 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, threshold: int): - """Initialize the LoginSetCompression packet. - - :param threshold: - Maximum size of a packet before it is compressed. All packets smaller than this will remain uncompressed. - To disable compression completely, threshold can be set to -1. - """ - self.threshold = threshold + threshold: int @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: buf.write_varint(self.threshold) - return buf @override @classmethod diff --git a/mcproto/packets/status/ping.py b/mcproto/packets/status/ping.py index 6ab4e355..29157e5a 100644 --- a/mcproto/packets/status/ping.py +++ b/mcproto/packets/status/ping.py @@ -7,33 +7,31 @@ from mcproto.buffer import Buffer from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket from mcproto.protocol.base_io import StructFormat +from mcproto.utils.abc import dataclass __all__ = ["PingPong"] @final +@dataclass class PingPong(ClientBoundPacket, ServerBoundPacket): - """Ping request/Pong response (Server <-> Client).""" + """Ping request/Pong response (Server <-> Client). - __slots__ = ("payload",) + Initialize the PingPong packet. + + :param payload: + Random number to test out the connection. Ideally, this number should be quite big, + however it does need to fit within the limit of a signed long long (-2 ** 63 to 2 ** 63 - 1). + """ PACKET_ID: ClassVar[int] = 0x01 GAME_STATE: ClassVar[GameState] = GameState.STATUS - def __init__(self, payload: int): - """Initialize the PingPong packet. - - :param payload: - Random number to test out the connection. Ideally, this number should be quite big, - however it does need to fit within the limit of a signed long long (-2 ** 63 to 2 ** 63 - 1). - """ - self.payload = payload + payload: int @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: buf.write_value(StructFormat.LONGLONG, self.payload) - return buf @override @classmethod diff --git a/mcproto/packets/status/status.py b/mcproto/packets/status/status.py index 4d1ad173..f9a5884c 100644 --- a/mcproto/packets/status/status.py +++ b/mcproto/packets/status/status.py @@ -7,22 +7,22 @@ from mcproto.buffer import Buffer from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket +from mcproto.utils.abc import dataclass __all__ = ["StatusRequest", "StatusResponse"] @final +@dataclass class StatusRequest(ServerBoundPacket): """Request from the client to get information on the server. (Client -> Server).""" - __slots__ = () - PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.STATUS @override - def serialize(self) -> Buffer: # pragma: no cover, nothing to test here. - return Buffer() + def serialize_to(self, buf: Buffer) -> None: + return # pragma: no cover, nothing to test here. @override @classmethod @@ -31,27 +31,24 @@ def _deserialize(cls, buf: Buffer, /) -> Self: # pragma: no cover, nothing to t @final +@dataclass class StatusResponse(ClientBoundPacket): - """Response from the server to requesting client with status data information. (Server -> Client).""" + """Response from the server to requesting client with status data information. (Server -> Client). + + Initialize the StatusResponse packet. - __slots__ = ("data",) + :param data: JSON response data sent back to the client. + """ PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.STATUS - def __init__(self, data: dict[str, Any]): - """Initialize the StatusResponse packet. - - :param data: JSON response data sent back to the client. - """ - self.data = data + data: dict[str, Any] # JSON response data sent back to the client. @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: s = json.dumps(self.data) buf.write_utf(s) - return buf @override @classmethod @@ -59,3 +56,11 @@ def _deserialize(cls, buf: Buffer, /) -> Self: s = buf.read_utf() data_ = json.loads(s) return cls(data_) + + @override + def validate(self) -> None: + # Ensure the data is serializable to JSON + try: + json.dumps(self.data) + except TypeError as exc: + raise ValueError("Data is not serializable to JSON.") from exc diff --git a/mcproto/protocol/base_io.py b/mcproto/protocol/base_io.py index b2b6fc61..f20948b6 100644 --- a/mcproto/protocol/base_io.py +++ b/mcproto/protocol/base_io.py @@ -155,9 +155,9 @@ async def write_bytearray(self, data: bytes, /) -> None: async def write_ascii(self, value: str, /) -> None: """Write ISO-8859-1 encoded string, with NULL (0x00) at the end to indicate string end.""" - data = bytearray(value, "ISO-8859-1") + data = bytes(value, "ISO-8859-1") await self.write(data) - await self.write(bytearray.fromhex("00")) + await self.write(bytes([0])) async def write_utf(self, value: str, /) -> None: """Write a UTF-8 encoded string, prefixed with a varint of it's size (in bytes). @@ -174,7 +174,7 @@ async def write_utf(self, value: str, /) -> None: if len(value) > 32767: raise ValueError("Maximum character limit for writing strings is 32767 characters.") - data = bytearray(value, "utf-8") + data = bytes(value, "utf-8") await self.write_varint(len(data)) await self.write(data) @@ -272,9 +272,9 @@ def write_bytearray(self, data: bytes, /) -> None: def write_ascii(self, value: str, /) -> None: """Write ISO-8859-1 encoded string, with NULL (0x00) at the end to indicate string end.""" - data = bytearray(value, "ISO-8859-1") + data = bytes(value, "ISO-8859-1") self.write(data) - self.write(bytearray.fromhex("00")) + self.write(bytes([0])) def write_utf(self, value: str, /) -> None: """Write a UTF-8 encoded string, prefixed with a varint of it's size (in bytes). @@ -291,7 +291,7 @@ def write_utf(self, value: str, /) -> None: if len(value) > 32767: raise ValueError("Maximum character limit for writing strings is 32767 characters.") - data = bytearray(value, "utf-8") + data = bytes(value, "utf-8") self.write_varint(len(data)) self.write(data) diff --git a/mcproto/types/abc.py b/mcproto/types/abc.py index 03614720..41c998c4 100644 --- a/mcproto/types/abc.py +++ b/mcproto/types/abc.py @@ -1,8 +1,8 @@ from __future__ import annotations -from mcproto.utils.abc import Serializable +from mcproto.utils.abc import Serializable, dataclass -__all__ = ["MCType"] +__all__ = ["MCType", "dataclass"] # That way we can import it from mcproto.types.abc class MCType(Serializable): diff --git a/mcproto/types/chat.py b/mcproto/types/chat.py index 8b915aa2..16cc8265 100644 --- a/mcproto/types/chat.py +++ b/mcproto/types/chat.py @@ -6,7 +6,7 @@ from typing_extensions import Self, TypeAlias, override from mcproto.buffer import Buffer -from mcproto.types.abc import MCType +from mcproto.types.abc import MCType, dataclass __all__ = [ "ChatMessage", @@ -33,14 +33,12 @@ class RawChatMessageDict(TypedDict, total=False): RawChatMessage: TypeAlias = Union[RawChatMessageDict, "list[RawChatMessageDict]", str] +@dataclass @final class ChatMessage(MCType): """Minecraft chat message representation.""" - __slots__ = ("raw",) - - def __init__(self, raw: RawChatMessage): - self.raw = raw + raw: RawChatMessage def as_dict(self) -> RawChatMessageDict: """Convert received ``raw`` into a stadard :class:`dict` form.""" @@ -50,8 +48,10 @@ def as_dict(self) -> RawChatMessageDict: return RawChatMessageDict(text=self.raw) if isinstance(self.raw, dict): # pyright: ignore[reportUnnecessaryIsInstance] return self.raw - # pragma: no cover - raise TypeError(f"Found unexpected type ({self.raw.__class__!r}) ({self.raw!r}) in `raw` attribute") + + raise TypeError( # pragma: no cover + f"Found unexpected type ({self.raw.__class__!r}) ({self.raw!r}) in `raw` attribute" + ) @override def __eq__(self, other: object) -> bool: @@ -67,11 +67,9 @@ def __eq__(self, other: object) -> bool: return self.raw == other.raw @override - def serialize(self) -> Buffer: + def serialize_to(self, buf: Buffer) -> None: txt = json.dumps(self.raw) - buf = Buffer() buf.write_utf(txt) - return buf @override @classmethod @@ -79,3 +77,19 @@ def deserialize(cls, buf: Buffer, /) -> Self: txt = buf.read_utf() dct = json.loads(txt) return cls(dct) + + @override + def validate(self) -> None: + if not isinstance(self.raw, (dict, list, str)): # type: ignore[unreachable] + raise TypeError(f"Expected `raw` to be a dict, list or str, got {self.raw!r} instead") + if isinstance(self.raw, dict): # We want to keep it this way for readability + if "text" not in self.raw and "extra" not in self.raw: + raise AttributeError("Expected `raw` to have either 'text' or 'extra' key, got neither") + if isinstance(self.raw, list): + for elem in self.raw: + if not isinstance(elem, dict): # type: ignore[unreachable] + raise TypeError(f"Expected `raw` to be a list of dicts, got {elem!r} instead") + if "text" not in elem and "extra" not in elem: + raise AttributeError( + "Expected each element in `raw` to have either 'text' or 'extra' key, got neither" + ) diff --git a/mcproto/types/uuid.py b/mcproto/types/uuid.py index 97fddde8..500710b5 100644 --- a/mcproto/types/uuid.py +++ b/mcproto/types/uuid.py @@ -12,7 +12,7 @@ @final -class UUID(MCType, uuid.UUID): +class UUID(uuid.UUID, MCType): """Minecraft UUID type. In order to support potential future changes in protocol version, and implement McType, @@ -22,12 +22,11 @@ class UUID(MCType, uuid.UUID): __slots__ = () @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: buf.write(self.bytes) - return buf @override @classmethod def deserialize(cls, buf: Buffer, /) -> Self: - return cls(bytes=bytes(buf.read(16))) + data = bytes(buf.read(16)) + return cls(bytes=data) diff --git a/mcproto/utils/abc.py b/mcproto/utils/abc.py index c2873494..5c7dcc64 100644 --- a/mcproto/utils/abc.py +++ b/mcproto/utils/abc.py @@ -1,20 +1,31 @@ from __future__ import annotations -from abc import ABC, abstractmethod +import sys +from abc import ABC, ABCMeta, abstractmethod from collections.abc import Sequence -from typing import ClassVar +from dataclasses import dataclass as _dataclass +from functools import partial +from typing import Any, ClassVar, TYPE_CHECKING from typing_extensions import Self from mcproto.buffer import Buffer -__all__ = ["RequiredParamsABCMixin", "Serializable"] +__all__ = ["RequiredParamsABCMixin", "Serializable", "dataclass"] + +if TYPE_CHECKING: + dataclass = _dataclass # The type checker needs + +if sys.version_info >= (3, 10): + dataclass = partial(_dataclass, slots=True) +else: + dataclass = _dataclass class RequiredParamsABCMixin: """Mixin class to ABCs that require certain attributes to be set in order to allow initialization. - This class performs a similar check to what :class:`~abc.ABC` already des with abstractmethods, + This class performs a similar check to what :class:`~abc.ABC` already does with abstractmethods, but for class variables. The required class variable names are set with :attr:`._REQUIRED_CLASS_VARS` class variable, which itself is automatically required. @@ -37,7 +48,7 @@ class vars which should be defined on given class directly. That means inheritan _REQUIRRED_CLASS_VARS: ClassVar[Sequence[str]] _REQUIRED_CLASS_VARS_NO_MRO: ClassVar[Sequence[str]] - def __new__(cls: type[Self], *a, **kw) -> Self: + def __new__(cls: type[Self], *a: Any, **kw: Any) -> Self: """Enforce required parameters being set for each instance of the concrete classes.""" _err_msg = f"Can't instantiate abstract {cls.__name__} class without defining " + "{!r} classvar" @@ -63,18 +74,60 @@ def __new__(cls: type[Self], *a, **kw) -> Self: return super().__new__(cls) -class Serializable(ABC): - """Base class for any type that should be (de)serializable into/from :class:`~mcproto.Buffer` data.""" +class _MetaDataclass(ABCMeta): + def __new__( + cls: type[_MetaDataclass], + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + **kwargs: Any, + ) -> Any: # Create the class using the super() method to ensure it is correctly formed as an ABC + new_class = super().__new__(cls, name, bases, namespace, **kwargs) + + # Check if the dataclass is already defined, if not, create it + if not hasattr(new_class, "__dataclass_fields__"): + new_class = dataclass(new_class) + + return new_class + + +class Serializable(ABC): # , metaclass=_MetaDataclass): + """Base class for any type that should be (de)serializable into/from :class:`~mcproto.Buffer` data. + + Any class that inherits from this class and adds parameters should use the :func:`~mcproto.utils.abc.dataclass` + decorator. + """ __slots__ = () - @abstractmethod + def __post_init__(self) -> None: + """Run the validation method after the object is initialized.""" + self.validate() + def serialize(self) -> Buffer: """Represent the object as a :class:`~mcproto.Buffer` (transmittable sequence of bytes).""" + self.validate() + buf = Buffer() + self.serialize_to(buf) + return buf + + @abstractmethod + def serialize_to(self, buf: Buffer, /) -> None: + """Write the object to a :class:`~mcproto.Buffer`.""" raise NotImplementedError + def validate(self) -> None: + """Validate the object's attributes, raising an exception if they are invalid. + + This will be called at the end of the object's initialization, and before serialization. + Use cast() in serialize_to() if your validation asserts that a value is of a certain type. + + By default, this method does nothing. Override it in your subclass to add validation logic. + """ + return + @classmethod @abstractmethod def deserialize(cls, buf: Buffer, /) -> Self: - """Construct the object from a :class:`~mcproto.Buffer` (transmitable sequence of bytes).""" + """Construct the object from a :class:`~mcproto.Buffer` (transmittable sequence of bytes).""" raise NotImplementedError diff --git a/tests/helpers.py b/tests/helpers.py index 44de2032..6e009b3e 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,14 +4,20 @@ import inspect import unittest.mock from collections.abc import Callable, Coroutine -from typing import Any, Generic, TypeVar +from typing import Any, Dict, Generic, Tuple, TypeVar, cast +import pytest from typing_extensions import ParamSpec, override +from mcproto.buffer import Buffer +from mcproto.utils.abc import Serializable + T = TypeVar("T") P = ParamSpec("P") T_Mock = TypeVar("T_Mock", bound=unittest.mock.Mock) +__all__ = ["synchronize", "SynchronizedMixin", "UnpropagatingMockMixin", "CustomMockMixin", "gen_serializable_test"] + def synchronize(f: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]: """Take an asynchronous function, and return a synchronous alternative. @@ -160,3 +166,175 @@ def __init__(self, **kwargs): if "spec_set" in kwargs: self.spec_set = kwargs.pop("spec_set") super().__init__(spec_set=self.spec_set, **kwargs) # type: ignore # Mixin class, this __init__ is valid + + +def gen_serializable_test( + context: dict[str, Any], + cls: type[Serializable], + fields: list[tuple[str, type]], + test_data: list[ + tuple[tuple[Any, ...], bytes] | tuple[tuple[Any, ...], type[Exception]] | tuple[type[Exception], bytes] + ], +): + """Generate tests for a serializable class. + + This function generates tests for the serialization, deserialization, validation, and deserialization error + handling + + :param context: The context to add the test functions to. This is usually `globals()`. + :param cls: The serializable class to test. + :param fields: A list of tuples containing the field names and types of the serializable class. + :param test_data: A list of test data. Each element is a tuple containing either: + - A tuple of parameters to pass to the serializable class constructor and the expected bytes after + serialization + - A tuple of parameters to pass to the serializable class constructor and the expected exception during + validation + - An exception to expect during deserialization and the bytes to deserialize + + Example usage: + ```python + @final + @dataclass + class ToyClass(Serializable): + a: int + b: str + + def serialize_to(self, buf: Buffer): + buf.write_varint(self.a) + buf.write_utf(self.b) + + @classmethod + def deserialize(cls, buf: Buffer) -> "ToyClass": + a = buf.read_varint() + b = buf.read_utf() + if len(b) > 10: + raise ValueError("b must be less than 10 characters") + return cls(a, b) + + def validate(self) -> None: + if self.a == 0: + raise ZeroDivisionError("a must be non-zero") + + gen_serializable_test( + context=globals(), + cls=ToyClass, + fields=[("a", int), ("b", str)], + test_data=[ + ((1, "hello"), b"\x01\x05hello"), + ((2, "world"), b"\x02\x05world"), + ((0, "hello"), ZeroDivisionError), + (IOError, b"\x01"), # Not enough data to deserialize + (ValueError, b"\x01\x0bhello world"), # 0b = 11 is too long + ], + ) + ``` + This will add 1 class test with 4 test functions containing the tests for serialization, deserialization, + validation, and deserialization error handling + + + """ + # Separate the test data into parameters for each test function + # This holds the parameters for the serialization and deserialization tests + parameters: list[tuple[dict[str, Any], bytes]] = [] + + # This holds the parameters for the validation tests + validation_fail: list[tuple[dict[str, Any], type[Exception]]] = [] + + # This holds the parameters for the deserialization error tests + deserialization_fail: list[tuple[bytes, type[Exception]]] = [] + + # kwargs = dict(zip([f[0] for f in fields], data_or_exc)) + for data_or_exc, expected_bytes_or_exc in test_data: + if isinstance(data_or_exc, tuple) and isinstance(expected_bytes_or_exc, bytes): + kwargs = dict(zip([f[0] for f in fields], data_or_exc)) + parameters.append((kwargs, expected_bytes_or_exc)) + elif isinstance(data_or_exc, type) and isinstance(expected_bytes_or_exc, bytes): + deserialization_fail.append((expected_bytes_or_exc, data_or_exc)) + else: + data = cast(Tuple[Any, ...], data_or_exc) + exception = cast(type[Exception], expected_bytes_or_exc) + kwargs = dict(zip([f[0] for f in fields], data)) + validation_fail.append((kwargs, exception)) + + def generate_name(param: dict[str, Any] | bytes, i: int) -> str: + """Generate a name for the test case.""" + length = 30 + result = f"{i:02d}] : " # the first [ is added by pytest + if isinstance(param, bytes): + try: + result += str(param[:length], "utf-8") + "..." if len(param) > (length + 3) else param.decode("utf-8") + except UnicodeDecodeError: + result += repr(param[:length]) + "..." if len(param) > (length + 3) else repr(param) + else: + param = cast(Dict[str, Any], param) + begin = ", ".join(f"{k}={v}" for k, v in param.items()) + result += begin[:length] + "..." if len(begin) > (length + 3) else begin + result = result.replace("\n", "\\n").replace("\r", "\\r") + result += f" [{cls.__name__}" # the other [ is added by pytest + return result + + class TestClass: + """Test class for the generated tests.""" + + @pytest.mark.parametrize( + ("kwargs", "expected_bytes"), + parameters, + ids=tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(parameters)), + ) + def test_serialization(self, kwargs: dict[str, Any], expected_bytes: bytes): + """Test serialization of the object.""" + obj = cls(**kwargs) + serialized_bytes = obj.serialize() + assert serialized_bytes == expected_bytes + + @pytest.mark.parametrize( + ("kwargs", "expected_bytes"), + parameters, + ids=tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(parameters)), + ) + def test_deserialization(self, kwargs: dict[str, Any], expected_bytes: bytes): + """Test deserialization of the object.""" + buf = Buffer(expected_bytes) + obj = cls.deserialize(buf) + assert cls(**kwargs) == obj, f"{cls.__name__}({kwargs}) != {obj}" + assert buf.remaining == 0, f"Buffer has {buf.remaining} bytes remaining" + + @pytest.mark.parametrize( + ("kwargs", "exc"), + validation_fail, + ids=tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(validation_fail)), + ) + def test_validation(self, kwargs: dict[str, Any], exc: type[Exception]): + """Test validation of the object.""" + with pytest.raises(exc): + cls(**kwargs) + + @pytest.mark.parametrize( + ("content", "exc"), + deserialization_fail, + ids=tuple(generate_name(content, i) for i, (content, _) in enumerate(deserialization_fail)), + ) + def test_deserialization_error(self, content: bytes, exc: type[Exception]): + """Test deserialization error handling.""" + buf = Buffer(content) + with pytest.raises(exc): + cls.deserialize(buf) + + if len(parameters) == 0: + # If there are no serialization tests, remove them + del TestClass.test_serialization + del TestClass.test_deserialization + + if len(validation_fail) == 0: + # If there are no validation tests, remove them + del TestClass.test_validation + + if len(deserialization_fail) == 0: + # If there are no deserialization error tests, remove them + del TestClass.test_deserialization_error + + # Set the names of the class + TestClass.__name__ = f"TestGen{cls.__name__}" + + # Add the test functions to the global context + context[TestClass.__name__] = TestClass diff --git a/tests/mcproto/packets/handshaking/test_handshake.py b/tests/mcproto/packets/handshaking/test_handshake.py index 42eb395e..0e34952f 100644 --- a/tests/mcproto/packets/handshaking/test_handshake.py +++ b/tests/mcproto/packets/handshaking/test_handshake.py @@ -1,101 +1,38 @@ from __future__ import annotations -from typing import Any - -import pytest - -from mcproto.buffer import Buffer from mcproto.packets.handshaking.handshake import Handshake, NextState - - -@pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"protocol_version": 757, "server_address": "mc.aircs.racing", "server_port": 25565, "next_state": 2}, - bytes.fromhex("f5050f6d632e61697263732e726163696e6763dd02"), - ), - ( - {"protocol_version": 757, "server_address": "mc.aircs.racing", "server_port": 25565, "next_state": 1}, - bytes.fromhex("f5050f6d632e61697263732e726163696e6763dd01"), - ), - ( - { - "protocol_version": 757, - "server_address": "hypixel.net", - "server_port": 25565, - "next_state": NextState.LOGIN, - }, - bytes.fromhex("f5050b6879706978656c2e6e657463dd02"), - ), - ( - { - "protocol_version": 757, - "server_address": "hypixel.net", - "server_port": 25565, - "next_state": NextState.STATUS, - }, - bytes.fromhex("f5050b6879706978656c2e6e657463dd01"), - ), +from tests.helpers import gen_serializable_test + +gen_serializable_test( + context=globals(), + cls=Handshake, + fields=[ + ("protocol_version", int), + ("server_address", str), + ("server_port", int), + ("next_state", NextState), ], -) -def test_serialize(kwargs: dict[str, Any], expected_bytes: list[int]): - """Test serialization of Handshake packet.""" - handshake = Handshake(**kwargs) - assert handshake.serialize().flush() == bytearray(expected_bytes) - - -@pytest.mark.parametrize( - ("read_bytes", "expected_out"), - [ + test_data=[ ( + (757, "mc.aircs.racing", 25565, NextState.LOGIN), bytes.fromhex("f5050f6d632e61697263732e726163696e6763dd02"), - { - "protocol_version": 757, - "server_address": "mc.aircs.racing", - "server_port": 25565, - "next_state": NextState.LOGIN, - }, ), ( + (757, "mc.aircs.racing", 25565, NextState.STATUS), bytes.fromhex("f5050f6d632e61697263732e726163696e6763dd01"), - { - "protocol_version": 757, - "server_address": "mc.aircs.racing", - "server_port": 25565, - "next_state": NextState.STATUS, - }, ), ( + (757, "hypixel.net", 25565, NextState.LOGIN), bytes.fromhex("f5050b6879706978656c2e6e657463dd02"), - { - "protocol_version": 757, - "server_address": "hypixel.net", - "server_port": 25565, - "next_state": NextState.LOGIN, - }, ), ( + (757, "hypixel.net", 25565, NextState.STATUS), bytes.fromhex("f5050b6879706978656c2e6e657463dd01"), - { - "protocol_version": 757, - "server_address": "hypixel.net", - "server_port": 25565, - "next_state": NextState.STATUS, - }, ), + # Invalid next state + ((757, "localhost", 25565, 3), ValueError), + ((757, "localhost", 25565, 4), ValueError), + ((757, "localhost", 25565, 5), ValueError), + ((757, "localhost", 25565, 6), ValueError), ], ) -def test_deserialize(read_bytes: list[int], expected_out: dict[str, Any]): - """Test deserialization of Handshake packet.""" - handshake = Handshake.deserialize(Buffer(read_bytes)) - - for i, v in expected_out.items(): - assert getattr(handshake, i) == v - - -@pytest.mark.parametrize(("state"), [3, 4, 5, 6]) -def test_invalid_state(state): - """Test initialization of Handshake packet with invalid next state raises :exc:`ValueError`.""" - with pytest.raises(ValueError): - Handshake(protocol_version=757, server_address="localhost", server_port=25565, next_state=state) diff --git a/tests/mcproto/packets/login/test_login.py b/tests/mcproto/packets/login/test_login.py index 23c21cfc..01a8d45d 100644 --- a/tests/mcproto/packets/login/test_login.py +++ b/tests/mcproto/packets/login/test_login.py @@ -1,10 +1,5 @@ from __future__ import annotations -from typing import Any - -import pytest - -from mcproto.buffer import Buffer from mcproto.packets.login.login import ( LoginDisconnect, LoginEncryptionRequest, @@ -17,286 +12,129 @@ ) from mcproto.types.chat import ChatMessage from mcproto.types.uuid import UUID +from tests.helpers import gen_serializable_test from tests.mcproto.test_encryption import RSA_PUBLIC_KEY +# LoginStart +gen_serializable_test( + context=globals(), + cls=LoginStart, + fields=[("username", str), ("uuid", UUID)], + test_data=[ + ( + ("ItsDrike", UUID("f70b4a42c9a04ffb92a31390c128a1b2")), + bytes.fromhex("084974734472696b65f70b4a42c9a04ffb92a31390c128a1b2"), + ), + ( + ("foobar1", UUID("7a82476416fc4e8b8686a99c775db7d3")), + bytes.fromhex("07666f6f626172317a82476416fc4e8b8686a99c775db7d3"), + ), + ], +) -class TestLoginStart: - """Collection of tests for the LoginStart packet.""" - - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"username": "ItsDrike", "uuid": UUID("f70b4a42c9a04ffb92a31390c128a1b2")}, - bytes.fromhex("084974734472696b65f70b4a42c9a04ffb92a31390c128a1b2"), - ), - ( - {"username": "foobar1", "uuid": UUID("7a82476416fc4e8b8686a99c775db7d3")}, - bytes.fromhex("07666f6f626172317a82476416fc4e8b8686a99c775db7d3"), - ), - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginStart packet.""" - packet = LoginStart(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) - - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("084974734472696b65f70b4a42c9a04ffb92a31390c128a1b2"), - {"username": "ItsDrike", "uuid": UUID("f70b4a42c9a04ffb92a31390c128a1b2")}, - ), - ( - bytes.fromhex("07666f6f626172317a82476416fc4e8b8686a99c775db7d3"), - {"username": "foobar1", "uuid": UUID("7a82476416fc4e8b8686a99c775db7d3")}, - ), - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginStart packet.""" - packet = LoginStart.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val - - -class TestLoginEncryptionRequest: - """Collection of tests for the LoginEncryptionRequest packet.""" - - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"public_key": RSA_PUBLIC_KEY, "verify_token": bytes.fromhex("9bd416ef"), "server_id": "a" * 20}, - bytes.fromhex( - "146161616161616161616161616161616161616161a20130819f300d06092a864886f70d010101050003818d003081890" - "2818100cb515109911ea3e4740d8a17a7ccd9cf226c83c7615e4a5505cd124571ee210a4ba26c7c42e15f51fcb7fa90dc" - "e6f83ebe0e163817c7d9fb1af7d981e90da2cc06ea59d01ff9fbb76b1803a0fe5af4a2c75145d89eb03e6a4aae21d2e7d" - "4c3938a298da575e12e0ae178d61a69bc0ea0b381790f182d9dba715bfb503c99d92b0203010001049bd416ef" - ), - ), - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginEncryptionRequest packet.""" - packet = LoginEncryptionRequest(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) - - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex( - "146161616161616161616161616161616161616161a20130819f300d06092a864886f70d010101050003818d003081890" - "2818100cb515109911ea3e4740d8a17a7ccd9cf226c83c7615e4a5505cd124571ee210a4ba26c7c42e15f51fcb7fa90dc" - "e6f83ebe0e163817c7d9fb1af7d981e90da2cc06ea59d01ff9fbb76b1803a0fe5af4a2c75145d89eb03e6a4aae21d2e7d" - "4c3938a298da575e12e0ae178d61a69bc0ea0b381790f182d9dba715bfb503c99d92b0203010001049bd416ef" - ), - {"public_key": RSA_PUBLIC_KEY, "verify_token": bytes.fromhex("9bd416ef"), "server_id": "a" * 20}, +# LoginEncryptionRequest +gen_serializable_test( + context=globals(), + cls=LoginEncryptionRequest, + fields=[("server_id", str), ("public_key", bytes), ("verify_token", bytes)], + test_data=[ + ( + ("a" * 20, RSA_PUBLIC_KEY, bytes.fromhex("9bd416ef")), + bytes.fromhex( + "146161616161616161616161616161616161616161a20130819f300d06092a864886f70d010101050003818d003081890" + "2818100cb515109911ea3e4740d8a17a7ccd9cf226c83c7615e4a5505cd124571ee210a4ba26c7c42e15f51fcb7fa90dc" + "e6f83ebe0e163817c7d9fb1af7d981e90da2cc06ea59d01ff9fbb76b1803a0fe5af4a2c75145d89eb03e6a4aae21d2e7d" + "4c3938a298da575e12e0ae178d61a69bc0ea0b381790f182d9dba715bfb503c99d92b0203010001049bd416ef" ), - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginEncryptionRequest packet.""" - packet = LoginEncryptionRequest.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val - - -class TestLoginEncryptionResponse: - """Collection of tests for the LoginEncryptionResponse packet.""" - - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"shared_secret": b"I'm shared", "verify_token": b"Token"}, - bytes.fromhex("0a49276d2073686172656405546f6b656e"), - ) - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginEncryptionRespones packet.""" - packet = LoginEncryptionResponse(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) - - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("0a49276d2073686172656405546f6b656e"), - {"shared_secret": b"I'm shared", "verify_token": b"Token"}, - ) - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginEncryptionRespones packet.""" - packet = LoginEncryptionResponse.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val - - -class TestLoginSuccess: - """Collection of tests for the LoginSuccess packet.""" - - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"uuid": UUID("f70b4a42c9a04ffb92a31390c128a1b2"), "username": "Mario"}, - bytes.fromhex("f70b4a42c9a04ffb92a31390c128a1b2054d6172696f"), - ) - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginSuccess packet.""" - packet = LoginSuccess(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) - - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("f70b4a42c9a04ffb92a31390c128a1b2054d6172696f"), - {"uuid": UUID("f70b4a42c9a04ffb92a31390c128a1b2"), "username": "Mario"}, - ) - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginSuccess packet.""" - packet = LoginSuccess.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val - - -class TestLoginDisconnect: - """Collection of tests for the LoginDisconnect packet.""" - - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"reason": ChatMessage("You are banned.")}, - bytes.fromhex("1122596f75206172652062616e6e65642e22"), - ) - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginDisconnect packet.""" - packet = LoginDisconnect(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) - - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("1122596f75206172652062616e6e65642e22"), - {"reason": ChatMessage("You are banned.")}, - ) - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginDisconnect packet.""" - packet = LoginDisconnect.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val + ), + (ValueError, bytes.fromhex("14")), + ], +) -class TestLoginPluginRequest: - """Collection of tests for the LoginPluginRequest packet.""" +def test_login_encryption_request_noid(): + """Test LoginEncryptionRequest without server_id.""" + packet = LoginEncryptionRequest(server_id=None, public_key=RSA_PUBLIC_KEY, verify_token=bytes.fromhex("9bd416ef")) + assert packet.server_id == " " * 20 # None is converted to an empty server id - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"message_id": 0, "channel": "xyz", "data": b"Hello"}, - bytes.fromhex("000378797a48656c6c6f"), - ) - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginPluginRequest packet.""" - packet = LoginPluginRequest(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("000378797a48656c6c6f"), - {"message_id": 0, "channel": "xyz", "data": b"Hello"}, - ) - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test serialization of LoginPluginRequest packet.""" - packet = LoginPluginRequest.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val +# TestLoginEncryptionResponse +gen_serializable_test( + context=globals(), + cls=LoginEncryptionResponse, + fields=[("shared_secret", bytes), ("verify_token", bytes)], + test_data=[ + ( + (b"I'm shared", b"Token"), + bytes.fromhex("0a49276d2073686172656405546f6b656e"), + ), + ], +) -class TestLoginPluginResponse: - """Collection of tests for the LoginPluginResponse packet.""" +# LoginSuccess +gen_serializable_test( + context=globals(), + cls=LoginSuccess, + fields=[("uuid", UUID), ("username", str)], + test_data=[ + ( + (UUID("f70b4a42c9a04ffb92a31390c128a1b2"), "Mario"), + bytes.fromhex("f70b4a42c9a04ffb92a31390c128a1b2054d6172696f"), + ), + ], +) - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"message_id": 0, "data": b"Hello"}, - bytes.fromhex("000148656c6c6f"), - ) - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginPluginResponse packet.""" - packet = LoginPluginResponse(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) +# LoginDisconnect +gen_serializable_test( + context=globals(), + cls=LoginDisconnect, + fields=[("reason", ChatMessage)], + test_data=[ + ( + (ChatMessage("You are banned."),), + bytes.fromhex("1122596f75206172652062616e6e65642e22"), + ), + ], +) - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("000148656c6c6f"), - {"message_id": 0, "data": b"Hello"}, - ) - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginPluginResponse packet.""" - packet = LoginPluginResponse.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val +# LoginPluginRequest +gen_serializable_test( + context=globals(), + cls=LoginPluginRequest, + fields=[("message_id", int), ("channel", str), ("data", bytes)], + test_data=[ + ( + (0, "xyz", b"Hello"), + bytes.fromhex("000378797a48656c6c6f"), + ), + ], +) -class TestLoginSetCompression: - """Collection of tests for the LoginSetCompression packet.""" - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"threshold": 2}, - bytes.fromhex("02"), - ) - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginSetCompression packet.""" - packet = LoginSetCompression(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) +# LoginPluginResponse +gen_serializable_test( + context=globals(), + cls=LoginPluginResponse, + fields=[("message_id", int), ("data", bytes)], + test_data=[ + ( + (0, b"Hello"), + bytes.fromhex("000148656c6c6f"), + ), + ], +) - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("02"), - {"threshold": 2}, - ) - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginSetCompression packet.""" - packet = LoginSetCompression.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val +# LoginSetCompression +gen_serializable_test( + context=globals(), + cls=LoginSetCompression, + fields=[("threshold", int)], + test_data=[ + ( + (2,), + bytes.fromhex("02"), + ), + ], +) diff --git a/tests/mcproto/packets/status/test_ping.py b/tests/mcproto/packets/status/test_ping.py index bef9248b..245a03aa 100644 --- a/tests/mcproto/packets/status/test_ping.py +++ b/tests/mcproto/packets/status/test_ping.py @@ -1,36 +1,20 @@ from __future__ import annotations -from typing import Any - -import pytest - -from mcproto.buffer import Buffer from mcproto.packets.status.ping import PingPong - - -@pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ({"payload": 2806088}, bytes.fromhex("00000000002ad148")), - ({"payload": 123456}, bytes.fromhex("000000000001e240")), +from tests.helpers import gen_serializable_test + +gen_serializable_test( + context=globals(), + cls=PingPong, + fields=[("payload", int)], + test_data=[ + ( + (2806088,), + bytes.fromhex("00000000002ad148"), + ), + ( + (123456,), + bytes.fromhex("000000000001e240"), + ), ], ) -def test_serialize(kwargs: dict[str, Any], expected_bytes: list[int]): - """Test serialization of PingPong packet.""" - ping = PingPong(**kwargs) - assert ping.serialize().flush() == bytearray(expected_bytes) - - -@pytest.mark.parametrize( - ("read_bytes", "expected_out"), - [ - (bytes.fromhex("00000000002ad148"), {"payload": 2806088}), - (bytes.fromhex("000000000001e240"), {"payload": 123456}), - ], -) -def test_deserialize(read_bytes: list[int], expected_out: dict[str, Any]): - """Test deserialization of PingPong packet.""" - ping = PingPong.deserialize(Buffer(read_bytes)) - - for i, v in expected_out.items(): - assert getattr(ping, i) == v diff --git a/tests/mcproto/packets/status/test_status.py b/tests/mcproto/packets/status/test_status.py index 7f9244d5..bd40a31b 100644 --- a/tests/mcproto/packets/status/test_status.py +++ b/tests/mcproto/packets/status/test_status.py @@ -1,67 +1,36 @@ from __future__ import annotations -import json -from typing import Any +from typing import Any, Dict -import pytest - -from mcproto.buffer import Buffer from mcproto.packets.status.status import StatusResponse +from tests.helpers import gen_serializable_test - -@pytest.mark.parametrize( - ("data", "expected_bytes"), - [ +gen_serializable_test( + context=globals(), + cls=StatusResponse, + fields=[("data", Dict[str, Any])], + test_data=[ ( - { - "description": {"text": "A Minecraft Server"}, - "players": {"max": 20, "online": 0}, - "version": {"name": "1.18.1", "protocol": 757}, - }, - bytes.fromhex( - "797b226465736372697074696f6e223a7b2274657874223a2241204d696e6" - "5637261667420536572766572227d2c22706c6179657273223a7b226d6178" - "223a32302c226f6e6c696e65223a20307d2c2276657273696f6e223a7b226" - "e616d65223a22312e31382e31222c2270726f746f636f6c223a3735377d7d" + ( + { + "description": {"text": "A Minecraft Server"}, + "players": {"max": 20, "online": 0}, + "version": {"name": "1.18.1", "protocol": 757}, + }, ), - ), - ], -) -def test_serialize(data: dict[str, Any], expected_bytes: bytes): - """Test serialization of StatusResponse packet.""" - expected_buffer = Buffer(expected_bytes) - # Clear the length before the actual JSON data. JSON strings are encoded using UTF (StatusResponse uses - # `write_utf`), so `write_utf` writes the length of the string as a varint before writing the string itself. - expected_buffer.read_varint() - expected_bytes = expected_buffer.flush() - - buffer = StatusResponse(data=data).serialize() - buffer.read_varint() # Ditto - out = buffer.flush() - - assert json.loads(out) == json.loads(expected_bytes) - - -@pytest.mark.parametrize( - ("read_bytes", "expected_data"), - [ - ( bytes.fromhex( - "797b226465736372697074696f6e223a7b2274657874223a2241204d696e6" - "5637261667420536572766572227d2c22706c6179657273223a7b226d6178" - "223a32302c226f6e6c696e65223a20307d2c2276657273696f6e223a7b226" - "e616d65223a22312e31382e31222c2270726f746f636f6c223a3735377d7d" + "84017b226465736372697074696f6e223a207b2274657874223a202241204" + "d696e65637261667420536572766572227d2c2022706c6179657273223a20" + "7b226d6178223a2032302c20226f6e6c696e65223a20307d2c20227665727" + "3696f6e223a207b226e616d65223a2022312e31382e31222c202270726f74" + "6f636f6c223a203735377d7d" + # Contains spaces that are not present in the expected bytes. + # "5637261667420536572766572227d2c22706c6179657273223a7b226d6178" + # "223a32302c226f6e6c696e65223a20307d2c2276657273696f6e223a7b226" + # "e616d65223a22312e31382e31222c2270726f746f636f6c223a3735377d7d" ), - { - "description": {"text": "A Minecraft Server"}, - "players": {"max": 20, "online": 0}, - "version": {"name": "1.18.1", "protocol": 757}, - }, ), + # Unserializable data for JSON + (({"data": object()},), ValueError), ], ) -def test_deserialize(read_bytes: list[int], expected_data: dict[str, Any]): - """Test deserialization of StatusResponse packet.""" - status = StatusResponse.deserialize(Buffer(read_bytes)) - - assert expected_data == status.data diff --git a/tests/mcproto/types/test_chat.py b/tests/mcproto/types/test_chat.py index 8730151c..3c8f05a6 100644 --- a/tests/mcproto/types/test_chat.py +++ b/tests/mcproto/types/test_chat.py @@ -2,54 +2,8 @@ import pytest -from mcproto.buffer import Buffer from mcproto.types.chat import ChatMessage, RawChatMessage, RawChatMessageDict - - -@pytest.mark.parametrize( - ("data", "expected_bytes"), - [ - ( - "A Minecraft Server", - bytearray.fromhex("142241204d696e6563726166742053657276657222"), - ), - ( - {"text": "abc"}, - bytearray.fromhex("0f7b2274657874223a2022616263227d"), - ), - ( - [{"text": "abc"}, {"text": "def"}], - bytearray.fromhex("225b7b2274657874223a2022616263227d2c207b2274657874223a2022646566227d5d"), - ), - ], -) -def test_serialize(data: RawChatMessage, expected_bytes: list[int]): - """Test serialization of ChatMessage results in expected bytes.""" - output_bytes = ChatMessage(data).serialize() - assert output_bytes == expected_bytes - - -@pytest.mark.parametrize( - ("input_bytes", "data"), - [ - ( - bytearray.fromhex("142241204d696e6563726166742053657276657222"), - "A Minecraft Server", - ), - ( - bytearray.fromhex("0f7b2274657874223a2022616263227d"), - {"text": "abc"}, - ), - ( - bytearray.fromhex("225b7b2274657874223a2022616263227d2c207b2274657874223a2022646566227d5d"), - [{"text": "abc"}, {"text": "def"}], - ), - ], -) -def test_deserialize(input_bytes: list[int], data: RawChatMessage): - """Test deserialization of ChatMessage with expected bytes produces expected ChatMessage.""" - chat = ChatMessage.deserialize(Buffer(input_bytes)) - assert chat.raw == data +from tests.helpers import gen_serializable_test @pytest.mark.parametrize( @@ -93,3 +47,30 @@ def test_as_dict(raw: RawChatMessage, expected_dict: RawChatMessageDict): def test_equality(raw1: RawChatMessage, raw2: RawChatMessage, expected_result: bool): """Test comparing ChatMessage instances produces expected equality result.""" assert (ChatMessage(raw1) == ChatMessage(raw2)) is expected_result + + +gen_serializable_test( + context=globals(), + cls=ChatMessage, + fields=[("raw", RawChatMessage)], + test_data=[ + ( + ("A Minecraft Server",), + bytes.fromhex("142241204d696e6563726166742053657276657222"), + ), + ( + ({"text": "abc"},), + bytes.fromhex("0f7b2274657874223a2022616263227d"), + ), + ( + ([{"text": "abc"}, {"text": "def"}],), + bytes.fromhex("225b7b2274657874223a2022616263227d2c207b2274657874223a2022646566227d5d"), + ), + # Wrong type for raw + ((b"invalid",), TypeError), + (({"no_extra_or_text": "invalid"},), AttributeError), + (([{"no_text": "invalid"}, {"text": "Hello"}, {"extra": "World"}],), AttributeError), + # Expects a list of dicts if raw is a list + (([[]],), TypeError), + ], +) diff --git a/tests/mcproto/types/test_uuid.py b/tests/mcproto/types/test_uuid.py index 956ba06e..ffdd7e7e 100644 --- a/tests/mcproto/types/test_uuid.py +++ b/tests/mcproto/types/test_uuid.py @@ -1,36 +1,20 @@ from __future__ import annotations -import pytest - -from mcproto.buffer import Buffer from mcproto.types.uuid import UUID +from tests.helpers import gen_serializable_test - -@pytest.mark.parametrize( - ("data", "expected_bytes"), - [ - ( - "12345678-1234-5678-1234-567812345678", - bytearray.fromhex("12345678123456781234567812345678"), - ), - ], -) -def test_serialize(data: str, expected_bytes: list[int]): - """Test serialization of UUID results in expected bytes.""" - output_bytes = UUID(data).serialize() - assert output_bytes == expected_bytes - - -@pytest.mark.parametrize( - ("input_bytes", "data"), - [ - ( - bytearray.fromhex("12345678123456781234567812345678"), - "12345678-1234-5678-1234-567812345678", - ), +gen_serializable_test( + context=globals(), + cls=UUID, + fields=[("hex", str)], + test_data=[ + (("12345678-1234-5678-1234-567812345678",), bytes.fromhex("12345678123456781234567812345678")), + # Too short or too long + (("12345678-1234-5678-1234-56781234567",), ValueError), + (("12345678-1234-5678-1234-5678123456789",), ValueError), + # Not enough data in the buffer (needs 16 bytes) + (IOError, b""), + (IOError, b"\x01"), + (IOError, b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e"), ], ) -def test_deserialize(input_bytes: list[int], data: str): - """Test deserialization of UUID with expected bytes produces expected UUID.""" - uuid = UUID.deserialize(Buffer(input_bytes)) - assert str(uuid) == data diff --git a/tests/mcproto/utils/test_serializable.py b/tests/mcproto/utils/test_serializable.py new file mode 100644 index 00000000..2b1b0fbd --- /dev/null +++ b/tests/mcproto/utils/test_serializable.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import final +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.utils.abc import Serializable, dataclass +from tests.helpers import gen_serializable_test + + +@final +@dataclass +class ToyClass(Serializable): + """Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`.""" + + a: int + b: str + + @override + def serialize_to(self, buf: Buffer): + """Write the object to a buffer.""" + buf.write_varint(self.a) + buf.write_utf(self.b) + + @classmethod + @override + def deserialize(cls, buf: Buffer) -> ToyClass: + """Deserialize the object from a buffer.""" + a = buf.read_varint() + if a == 0: + raise ZeroDivisionError("a must be non-zero") + b = buf.read_utf() + return cls(a, b) + + @override + def validate(self) -> None: + """Validate the object's attributes.""" + if self.a == 0: + raise ZeroDivisionError("a must be non-zero") + if len(self.b) > 10: + raise ValueError("b must be less than 10 characters") + + +gen_serializable_test( + context=globals(), + cls=ToyClass, + fields=[("a", int), ("b", str)], + test_data=[ + ((1, "hello"), b"\x01\x05hello"), + ((2, "world"), b"\x02\x05world"), + ((0, "hello"), ZeroDivisionError), + ((1, "hello world"), ValueError), + (ZeroDivisionError, b"\x00"), + (IOError, b"\x01"), + ], +)