diff --git a/changes/273.internal.md b/changes/273.internal.md new file mode 100644 index 00000000..4096e20f --- /dev/null +++ b/changes/273.internal.md @@ -0,0 +1,90 @@ +- Changed the way `Serializable` classes are handled: + + Here is how a basic `Serializable` class looks like: + + @final + @dataclass + class ToyClass(Serializable): + """Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`.""" + + + # Attributes can be of any type + a: int + b: str + + # dataclasses.field() can be used to specify additional metadata + + def serialize_to(self, buf: Buffer): + """Write the object to a buffer.""" + buf.write_varint(self.a) + buf.write_utf(self.b) + + @classmethod + def deserialize(cls, buf: Buffer) -> ToyClass: + """Deserialize the object from a buffer.""" + a = buf.read_varint() + if a == 0: + raise ZeroDivisionError("a must be non-zero") + b = buf.read_utf() + return cls(a, b) + + def validate(self) -> None: + """Validate the object's attributes.""" + if self.a == 0: + raise ZeroDivisionError("a must be non-zero") + if len(self.b) > 10: + raise ValueError("b must be less than 10 characters") + + + The `Serializable` class must implement the following methods: + + - `serialize_to(buf: Buffer) -> None`: Serializes the object to a buffer. + - `deserialize(buf: Buffer) -> Serializable`: Deserializes the object from a buffer. + - `validate() -> None`: Validates the object's attributes, raising an exception if they are invalid. + +- Added a test generator for `Serializable` classes: + + The `gen_serializable_test` function generates tests for `Serializable` classes. It takes the following arguments: + + - `context`: The dictionary containing the context in which the generated test class will be placed (e.g. `globals()`). + > Dictionary updates must reflect in the context. This is the case for `globals()` but implementation-specific for `locals()`. + - `cls`: The `Serializable` class to generate tests for. + - `fields`: A list of fields where the test values will be placed. + + > In the example above, the `ToyClass` class has two fields: `a` and `b`. + + - `test_data`: A list of tuples containing either: + - `((field1_value, field2_value, ...), expected_bytes)`: The values of the fields and the expected serialized bytes. This needs to work both ways, i.e. `cls(field1_value, field2_value, ...) == cls.deserialize(expected_bytes).` + - `((field1_value, field2_value, ...), exception)`: The values of the fields and the expected exception when validating the object. + - `(exception, bytes)`: The expected exception when deserializing the bytes and the bytes to deserialize. + + The `gen_serializable_test` function generates a test class with the following tests: + + gen_serializable_test( + context=globals(), + cls=ToyClass, + fields=[("a", int), ("b", str)], + test_data=[ + ((1, "hello"), b"\x01\x05hello"), + ((2, "world"), b"\x02\x05world"), + ((0, "hello"), ZeroDivisionError), + ((1, "hello world"), ValueError), + (ZeroDivisionError, b"\x00"), + (IOError, b"\x01"), + ], + ) + + The generated test class will have the following tests: + + class TestGenToyClass: + def test_serialization(self): + # 2 subtests for the cases 1 and 2 + + def test_deserialization(self): + # 2 subtests for the cases 1 and 2 + + def test_validation(self): + # 2 subtests for the cases 3 and 4 + + def test_exceptions(self): + # 2 subtests for the cases 5 and 6 diff --git a/mcproto/packets/handshaking/handshake.py b/mcproto/packets/handshaking/handshake.py index f7b04421..72435f40 100644 --- a/mcproto/packets/handshaking/handshake.py +++ b/mcproto/packets/handshaking/handshake.py @@ -1,17 +1,18 @@ from __future__ import annotations from enum import IntEnum -from typing import ClassVar, final +from typing import ClassVar, cast, final from typing_extensions import Self, override from mcproto.buffer import Buffer from mcproto.packets.packet import GameState, ServerBoundPacket from mcproto.protocol.base_io import StructFormat +from mcproto.utils.abc import dataclass __all__ = [ - "Handshake", "NextState", + "Handshake", ] @@ -23,49 +24,34 @@ class NextState(IntEnum): @final +@dataclass class Handshake(ServerBoundPacket): - """Initializes connection between server and client. (Client -> Server).""" + """Initializes connection between server and client. (Client -> Server). + + Initialize the Handshake packet. + + :param protocol_version: Protocol version number to be used. + :param server_address: The host/address the client is connecting to. + :param server_port: The port the client is connecting to. + :param next_state: The next state for the server to move into. + """ PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.HANDSHAKING - __slots__ = ("next_state", "protocol_version", "server_address", "server_port") - - def __init__( - self, - *, - protocol_version: int, - server_address: str, - server_port: int, - next_state: NextState | int, - ): - """Initialize the Handshake packet. - - :param protocol_version: Protocol version number to be used. - :param server_address: The host/address the client is connecting to. - :param server_port: The port the client is connecting to. - :param next_state: The next state for the server to move into. - """ - if not isinstance(next_state, NextState): # next_state is int - rev_lookup = {x.value: x for x in NextState.__members__.values()} - try: - next_state = rev_lookup[next_state] - except KeyError as exc: - raise ValueError("No such next_state.") from exc - - self.protocol_version = protocol_version - self.server_address = server_address - self.server_port = server_port - self.next_state = next_state + protocol_version: int + server_address: str + server_port: int + next_state: NextState | int @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: + """Serialize the packet.""" + self.next_state = cast(NextState, self.next_state) # Handled by the validate method buf.write_varint(self.protocol_version) buf.write_utf(self.server_address) buf.write_value(StructFormat.USHORT, self.server_port) buf.write_varint(self.next_state.value) - return buf @override @classmethod @@ -76,3 +62,13 @@ def _deserialize(cls, buf: Buffer, /) -> Self: server_port=buf.read_value(StructFormat.USHORT), next_state=buf.read_varint(), ) + + @override + def validate(self) -> None: + """Validate the packet.""" + if not isinstance(self.next_state, NextState): + rev_lookup = {x.value: x for x in NextState.__members__.values()} + try: + self.next_state = rev_lookup[self.next_state] + except KeyError as exc: + raise ValueError("No such next_state.") from exc diff --git a/mcproto/packets/login/login.py b/mcproto/packets/login/login.py index be900955..bffe3eb7 100644 --- a/mcproto/packets/login/login.py +++ b/mcproto/packets/login/login.py @@ -11,6 +11,7 @@ from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket from mcproto.types.chat import ChatMessage from mcproto.types.uuid import UUID +from mcproto.utils.abc import dataclass __all__ = [ "LoginDisconnect", @@ -25,29 +26,26 @@ @final +@dataclass class LoginStart(ServerBoundPacket): - """Packet from client asking to start login process. (Client -> Server).""" + """Packet from client asking to start login process. (Client -> Server). - __slots__ = ("username", "uuid") + Initialize the LoginStart packet. + + :param username: Username of the client who sent the request. + :param uuid: UUID of the player logging in (if the player doesn't have a UUID, this can be ``None``) + """ PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, *, username: str, uuid: UUID): - """Initialize the LoginStart packet. - - :param username: Username of the client who sent the request. - :param uuid: UUID of the player logging in (if the player doesn't have a UUID, this can be ``None``) - """ - self.username = username - self.uuid = uuid + username: str + uuid: UUID @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: buf.write_utf(self.username) - buf.extend(self.uuid.serialize()) - return buf + self.uuid.serialize_to(buf) @override @classmethod @@ -58,44 +56,39 @@ def _deserialize(cls, buf: Buffer, /) -> Self: @final +@dataclass class LoginEncryptionRequest(ClientBoundPacket): - """Used by the server to ask the client to encrypt the login process. (Server -> Client).""" + """Used by the server to ask the client to encrypt the login process. (Server -> Client). + + Initialize the LoginEncryptionRequest packet. - __slots__ = ("public_key", "server_id", "verify_token") + :param public_key: Server's public key. + :param verify_token: Sequence of random bytes generated by server for verification. + :param server_id: Empty on minecraft versions 1.7.X and higher (20 random chars pre 1.7). + """ PACKET_ID: ClassVar[int] = 0x01 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, *, server_id: str | None = None, public_key: RSAPublicKey, verify_token: bytes): - """Initialize the LoginEncryptionRequest packet. - - :param server_id: Empty on minecraft versions 1.7.X and higher (20 random chars pre 1.7). - :param public_key: Server's public key. - :param verify_token: Sequence of random bytes generated by server for verification. - """ - if server_id is None: - server_id = " " * 20 - - self.server_id = server_id - self.public_key = public_key - self.verify_token = verify_token + public_key: RSAPublicKey + verify_token: bytes + server_id: str | None = None @override - def serialize(self) -> Buffer: - public_key_raw = self.public_key.public_bytes(encoding=Encoding.DER, format=PublicFormat.SubjectPublicKeyInfo) + def serialize_to(self, buf: Buffer) -> None: + self.server_id = cast(str, self.server_id) - buf = Buffer() + public_key_raw = self.public_key.public_bytes(encoding=Encoding.DER, format=PublicFormat.SubjectPublicKeyInfo) buf.write_utf(self.server_id) buf.write_bytearray(public_key_raw) buf.write_bytearray(self.verify_token) - return buf @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: server_id = buf.read_utf() - public_key_raw = buf.read_bytearray() - verify_token = buf.read_bytearray() + public_key_raw = bytes(buf.read_bytearray()) + verify_token = bytes(buf.read_bytearray()) # Key type is determined by the passed key itself, we know in our case, it will # be an RSA public key, so we explicitly type-cast here. @@ -103,64 +96,65 @@ def _deserialize(cls, buf: Buffer, /) -> Self: return cls(server_id=server_id, public_key=public_key, verify_token=verify_token) + @override + def validate(self) -> None: + """Validate the packet.""" + if self.server_id is None: + self.server_id = " " * 20 + @final +@dataclass class LoginEncryptionResponse(ServerBoundPacket): - """Response from the client to :class:`LoginEncryptionRequest` packet. (Client -> Server).""" + """Response from the client to :class:`LoginEncryptionRequest` packet. (Client -> Server). + + Initialize the LoginEncryptionResponse packet. - __slots__ = ("shared_secret", "verify_token") + :param shared_secret: Shared secret value, encrypted with server's public key. + :param verify_token: Verify token value, encrypted with same public key. + """ PACKET_ID: ClassVar[int] = 0x01 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, *, shared_secret: bytes, verify_token: bytes): - """Initialize the LoginEncryptionResponse packet. - - :param shared_secret: Shared secret value, encrypted with server's public key. - :param verify_token: Verify token value, encrypted with same public key. - """ - self.shared_secret = shared_secret - self.verify_token = verify_token + shared_secret: bytes + verify_token: bytes @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: + """Serialize the packet.""" buf.write_bytearray(self.shared_secret) buf.write_bytearray(self.verify_token) - return buf @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - shared_secret = buf.read_bytearray() - verify_token = buf.read_bytearray() + shared_secret = bytes(buf.read_bytearray()) + verify_token = bytes(buf.read_bytearray()) return cls(shared_secret=shared_secret, verify_token=verify_token) @final +@dataclass class LoginSuccess(ClientBoundPacket): - """Sent by the server to denote a successful login. (Server -> Client).""" + """Sent by the server to denote a successful login. (Server -> Client). - __slots__ = ("username", "uuid") + Initialize the LoginSuccess packet. + + :param uuid: The UUID of the connecting player/client. + :param username: The username of the connecting player/client. + """ PACKET_ID: ClassVar[int] = 0x02 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, uuid: UUID, username: str): - """Initialize the LoginSuccess packet. - - :param uuid: The UUID of the connecting player/client. - :param username: The username of the connecting player/client. - """ - self.uuid = uuid - self.username = username + uuid: UUID + username: str @override - def serialize(self) -> Buffer: - buf = Buffer() - buf.extend(self.uuid.serialize()) + def serialize_to(self, buf: Buffer) -> None: + self.uuid.serialize_to(buf) buf.write_utf(self.username) - return buf @override @classmethod @@ -171,24 +165,23 @@ def _deserialize(cls, buf: Buffer, /) -> Self: @final +@dataclass class LoginDisconnect(ClientBoundPacket): - """Sent by the server to kick a player while in the login state. (Server -> Client).""" + """Sent by the server to kick a player while in the login state. (Server -> Client). + + Initialize the LoginDisconnect packet. - __slots__ = ("reason",) + :param reason: The reason for disconnection (kick). + """ PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, reason: ChatMessage): - """Initialize the LoginDisconnect packet. - - :param reason: The reason for disconnection (kick). - """ - self.reason = reason + reason: ChatMessage @override - def serialize(self) -> Buffer: - return self.reason.serialize() + def serialize_to(self, buf: Buffer) -> None: + self.reason.serialize_to(buf) @override @classmethod @@ -198,101 +191,92 @@ def _deserialize(cls, buf: Buffer, /) -> Self: @final +@dataclass class LoginPluginRequest(ClientBoundPacket): - """Sent by the server to implement a custom handshaking flow. (Server -> Client).""" + """Sent by the server to implement a custom handshaking flow. (Server -> Client). - __slots__ = ("channel", "data", "message_id") + Initialize the LoginPluginRequest. + + :param message_id: Message id, generated by the server, should be unique to the connection. + :param channel: Channel identifier, name of the plugin channel used to send data. + :param data: Data that is to be sent. + """ PACKET_ID: ClassVar[int] = 0x04 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, message_id: int, channel: str, data: bytes): - """Initialize the LoginPluginRequest. - - :param message_id: Message id, generated by the server, should be unique to the connection. - :param channel: Channel identifier, name of the plugin channel used to send data. - :param data: Data that is to be sent. - """ - self.message_id = message_id - self.channel = channel - self.data = data + message_id: int + channel: str + data: bytes @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: buf.write_varint(self.message_id) buf.write_utf(self.channel) buf.write(self.data) - return buf @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: message_id = buf.read_varint() channel = buf.read_utf() - data = buf.read(buf.remaining) # All of the remaining data in the buffer + data = bytes(buf.read(buf.remaining)) # All of the remaining data in the buffer return cls(message_id, channel, data) @final +@dataclass class LoginPluginResponse(ServerBoundPacket): - """Response to LoginPluginRequest from client. (Client -> Server).""" + """Response to LoginPluginRequest from client. (Client -> Server). - __slots__ = ("data", "message_id") + Initialize the LoginPluginRequest packet. + + :param message_id: Message id, generated by the server, should be unique to the connection. + :param data: Optional response data, present if client understood request. + """ PACKET_ID: ClassVar[int] = 0x02 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, message_id: int, data: bytes | None): - """Initialize the LoginPluginRequest packet. - - :param message_id: Message id, generated by the server, should be unique to the connection. - :param data: Optional response data, present if client understood request. - """ - self.message_id = message_id - self.data = data + message_id: int + data: bytes | None @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: buf.write_varint(self.message_id) buf.write_optional(self.data, buf.write) - return buf @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: message_id = buf.read_varint() - data = buf.read_optional(lambda: buf.read(buf.remaining)) + data = buf.read_optional(lambda: bytes(buf.read(buf.remaining))) return cls(message_id, data) @final +@dataclass class LoginSetCompression(ClientBoundPacket): """Sent by the server to specify whether to use compression on future packets or not (Server -> Client). - Note that this packet is optional, and if not set, the compression will not be enabled at all. - """ + Initialize the LoginSetCompression packet. + - __slots__ = ("threshold",) + :param threshold: + Maximum size of a packet before it is compressed. All packets smaller than this will remain uncompressed. + To disable compression completely, threshold can be set to -1. + + :note: This packet is optional, and if not set, the compression will not be enabled at all. + """ PACKET_ID: ClassVar[int] = 0x03 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - def __init__(self, threshold: int): - """Initialize the LoginSetCompression packet. - - :param threshold: - Maximum size of a packet before it is compressed. All packets smaller than this will remain uncompressed. - To disable compression completely, threshold can be set to -1. - """ - self.threshold = threshold + threshold: int @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: buf.write_varint(self.threshold) - return buf @override @classmethod diff --git a/mcproto/packets/packet.py b/mcproto/packets/packet.py index 9db8241d..4f65a0b7 100644 --- a/mcproto/packets/packet.py +++ b/mcproto/packets/packet.py @@ -113,9 +113,9 @@ def from_packet_class(cls, packet_class: type[Packet], buffer: Buffer, message: This is a convenience constructor, picking up the necessary parameters about the identified packet from the packet class automatically (packet id, game state, ...). """ - if isinstance(packet_class, ServerBoundPacket): + if issubclass(packet_class, ServerBoundPacket): direction = PacketDirection.SERVERBOUND - elif isinstance(packet_class, ClientBoundPacket): + elif issubclass(packet_class, ClientBoundPacket): direction = PacketDirection.CLIENTBOUND else: raise TypeError( diff --git a/mcproto/packets/status/ping.py b/mcproto/packets/status/ping.py index 6ab4e355..29157e5a 100644 --- a/mcproto/packets/status/ping.py +++ b/mcproto/packets/status/ping.py @@ -7,33 +7,31 @@ from mcproto.buffer import Buffer from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket from mcproto.protocol.base_io import StructFormat +from mcproto.utils.abc import dataclass __all__ = ["PingPong"] @final +@dataclass class PingPong(ClientBoundPacket, ServerBoundPacket): - """Ping request/Pong response (Server <-> Client).""" + """Ping request/Pong response (Server <-> Client). - __slots__ = ("payload",) + Initialize the PingPong packet. + + :param payload: + Random number to test out the connection. Ideally, this number should be quite big, + however it does need to fit within the limit of a signed long long (-2 ** 63 to 2 ** 63 - 1). + """ PACKET_ID: ClassVar[int] = 0x01 GAME_STATE: ClassVar[GameState] = GameState.STATUS - def __init__(self, payload: int): - """Initialize the PingPong packet. - - :param payload: - Random number to test out the connection. Ideally, this number should be quite big, - however it does need to fit within the limit of a signed long long (-2 ** 63 to 2 ** 63 - 1). - """ - self.payload = payload + payload: int @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: buf.write_value(StructFormat.LONGLONG, self.payload) - return buf @override @classmethod diff --git a/mcproto/packets/status/status.py b/mcproto/packets/status/status.py index 4d1ad173..f9a5884c 100644 --- a/mcproto/packets/status/status.py +++ b/mcproto/packets/status/status.py @@ -7,22 +7,22 @@ from mcproto.buffer import Buffer from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket +from mcproto.utils.abc import dataclass __all__ = ["StatusRequest", "StatusResponse"] @final +@dataclass class StatusRequest(ServerBoundPacket): """Request from the client to get information on the server. (Client -> Server).""" - __slots__ = () - PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.STATUS @override - def serialize(self) -> Buffer: # pragma: no cover, nothing to test here. - return Buffer() + def serialize_to(self, buf: Buffer) -> None: + return # pragma: no cover, nothing to test here. @override @classmethod @@ -31,27 +31,24 @@ def _deserialize(cls, buf: Buffer, /) -> Self: # pragma: no cover, nothing to t @final +@dataclass class StatusResponse(ClientBoundPacket): - """Response from the server to requesting client with status data information. (Server -> Client).""" + """Response from the server to requesting client with status data information. (Server -> Client). + + Initialize the StatusResponse packet. - __slots__ = ("data",) + :param data: JSON response data sent back to the client. + """ PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.STATUS - def __init__(self, data: dict[str, Any]): - """Initialize the StatusResponse packet. - - :param data: JSON response data sent back to the client. - """ - self.data = data + data: dict[str, Any] # JSON response data sent back to the client. @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: s = json.dumps(self.data) buf.write_utf(s) - return buf @override @classmethod @@ -59,3 +56,11 @@ def _deserialize(cls, buf: Buffer, /) -> Self: s = buf.read_utf() data_ = json.loads(s) return cls(data_) + + @override + def validate(self) -> None: + # Ensure the data is serializable to JSON + try: + json.dumps(self.data) + except TypeError as exc: + raise ValueError("Data is not serializable to JSON.") from exc diff --git a/mcproto/protocol/base_io.py b/mcproto/protocol/base_io.py index d94b7b9c..a17a6ac7 100644 --- a/mcproto/protocol/base_io.py +++ b/mcproto/protocol/base_io.py @@ -155,9 +155,9 @@ async def write_bytearray(self, data: bytes, /) -> None: async def write_ascii(self, value: str, /) -> None: """Write ISO-8859-1 encoded string, with NULL (0x00) at the end to indicate string end.""" - data = bytearray(value, "ISO-8859-1") + data = bytes(value, "ISO-8859-1") await self.write(data) - await self.write(bytearray.fromhex("00")) + await self.write(bytes([0])) async def write_utf(self, value: str, /) -> None: """Write a UTF-8 encoded string, prefixed with a varint of it's size (in bytes). @@ -174,7 +174,7 @@ async def write_utf(self, value: str, /) -> None: if len(value) > 32767: raise ValueError("Maximum character limit for writing strings is 32767 characters.") - data = bytearray(value, "utf-8") + data = bytes(value, "utf-8") await self.write_varint(len(data)) await self.write(data) @@ -272,9 +272,9 @@ def write_bytearray(self, data: bytes, /) -> None: def write_ascii(self, value: str, /) -> None: """Write ISO-8859-1 encoded string, with NULL (0x00) at the end to indicate string end.""" - data = bytearray(value, "ISO-8859-1") + data = bytes(value, "ISO-8859-1") self.write(data) - self.write(bytearray.fromhex("00")) + self.write(bytes([0])) def write_utf(self, value: str, /) -> None: """Write a UTF-8 encoded string, prefixed with a varint of it's size (in bytes). @@ -291,7 +291,7 @@ def write_utf(self, value: str, /) -> None: if len(value) > 32767: raise ValueError("Maximum character limit for writing strings is 32767 characters.") - data = bytearray(value, "utf-8") + data = bytes(value, "utf-8") self.write_varint(len(data)) self.write(data) diff --git a/mcproto/types/abc.py b/mcproto/types/abc.py index 03614720..41c998c4 100644 --- a/mcproto/types/abc.py +++ b/mcproto/types/abc.py @@ -1,8 +1,8 @@ from __future__ import annotations -from mcproto.utils.abc import Serializable +from mcproto.utils.abc import Serializable, dataclass -__all__ = ["MCType"] +__all__ = ["MCType", "dataclass"] # That way we can import it from mcproto.types.abc class MCType(Serializable): diff --git a/mcproto/types/chat.py b/mcproto/types/chat.py index 46cc42cd..fe978631 100644 --- a/mcproto/types/chat.py +++ b/mcproto/types/chat.py @@ -6,7 +6,7 @@ from typing_extensions import Self, TypeAlias, override from mcproto.buffer import Buffer -from mcproto.types.abc import MCType +from mcproto.types.abc import MCType, dataclass __all__ = [ "ChatMessage", @@ -33,14 +33,12 @@ class RawChatMessageDict(TypedDict, total=False): RawChatMessage: TypeAlias = Union[RawChatMessageDict, "list[RawChatMessageDict]", str] +@dataclass @final class ChatMessage(MCType): """Minecraft chat message representation.""" - __slots__ = ("raw",) - - def __init__(self, raw: RawChatMessage): - self.raw = raw + raw: RawChatMessage def as_dict(self) -> RawChatMessageDict: """Convert received ``raw`` into a stadard :class:`dict` form.""" @@ -50,8 +48,10 @@ def as_dict(self) -> RawChatMessageDict: return RawChatMessageDict(text=self.raw) if isinstance(self.raw, dict): # pyright: ignore[reportUnnecessaryIsInstance] return self.raw - # pragma: no cover - raise TypeError(f"Found unexpected type ({self.raw.__class__!r}) ({self.raw!r}) in `raw` attribute") + + raise TypeError( # pragma: no cover + f"Found unexpected type ({self.raw.__class__!r}) ({self.raw!r}) in `raw` attribute" + ) @override def __eq__(self, other: object) -> bool: @@ -67,11 +67,9 @@ def __eq__(self, other: object) -> bool: return self.raw == other.raw @override - def serialize(self) -> Buffer: + def serialize_to(self, buf: Buffer) -> None: txt = json.dumps(self.raw) - buf = Buffer() buf.write_utf(txt) - return buf @override @classmethod @@ -79,3 +77,19 @@ def deserialize(cls, buf: Buffer, /) -> Self: txt = buf.read_utf() dct = json.loads(txt) return cls(dct) + + @override + def validate(self) -> None: + if not isinstance(self.raw, (dict, list, str)): # type: ignore[unreachable] + raise TypeError(f"Expected `raw` to be a dict, list or str, got {self.raw!r} instead") + if isinstance(self.raw, dict): # We want to keep it this way for readability + if "text" not in self.raw and "extra" not in self.raw: + raise AttributeError("Expected `raw` to have either 'text' or 'extra' key, got neither") + if isinstance(self.raw, list): + for elem in self.raw: + if not isinstance(elem, dict): # type: ignore[unreachable] + raise TypeError(f"Expected `raw` to be a list of dicts, got {elem!r} instead") + if "text" not in elem and "extra" not in elem: + raise AttributeError( + "Expected each element in `raw` to have either 'text' or 'extra' key, got neither" + ) diff --git a/mcproto/types/nbt.py b/mcproto/types/nbt.py index e3d0375e..4fc0a04b 100644 --- a/mcproto/types/nbt.py +++ b/mcproto/types/nbt.py @@ -2,14 +2,14 @@ from abc import abstractmethod from enum import IntEnum -from typing import Union, cast, Protocol, runtime_checkable +from typing import ClassVar, Union, cast, Protocol, final, runtime_checkable from collections.abc import Iterator, Mapping, Sequence from typing_extensions import TypeAlias, override from mcproto.buffer import Buffer -from mcproto.protocol.base_io import StructFormat -from mcproto.types.abc import MCType +from mcproto.protocol.base_io import StructFormat, INT_FORMATS_TYPE, FLOAT_FORMATS_TYPE +from mcproto.types.abc import MCType, dataclass __all__ = [ "ByteArrayNBT", @@ -187,13 +187,9 @@ class NBTag(MCType, NBTagConvertible): __slots__ = ("name", "payload") - def __init__(self, payload: PayloadType, name: str = ""): - self.name = name - self.payload = payload - @override def serialize(self, with_type: bool = True, with_name: bool = True) -> Buffer: - """Serialize the NBT tag to a buffer. + """Serialize the NBT tag to a new buffer. :param with_type: Whether to include the type of the tag in the serialization. (Passed to :meth:`_write_header`) @@ -204,7 +200,7 @@ def serialize(self, with_type: bool = True, with_name: bool = True) -> Buffer: .. note:: The ``with_type`` and ``with_name`` parameters only control the first level of serialization. """ buf = Buffer() - self.write_to(buf, with_name=with_name, with_type=with_type) + self.serialize_to(buf, with_name=with_name, with_type=with_type) return buf @override @@ -231,12 +227,16 @@ def deserialize(cls, buf: Buffer, with_name: bool = True, with_type: bool = True tag.name = name return tag + @override @abstractmethod - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the NBT tag to the buffer. + def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Serialize the NBT tag to a buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. - Implementation shortcut used in :meth:`serialize`. (Subclasses can override this, avoiding some - repetition when compared to overriding ``serialize`` directly.) + .. seealso:: :meth:`serialize` """ raise NotImplementedError @@ -261,7 +261,7 @@ def _write_header(self, buf: Buffer, with_type: bool = True, with_name: bool = T tag_type = _get_tag_type(self) buf.write_value(StructFormat.BYTE, tag_type.value) if with_name and self.name: - StringNBT(self.name).write_to(buf, with_type=False, with_name=False) + StringNBT(self.name).serialize_to(buf, with_type=False, with_name=False) @classmethod def _read_header(cls, buf: Buffer, read_type: bool = True, with_name: bool = True) -> tuple[str, NBTagType]: @@ -355,7 +355,7 @@ def from_object(data: FromObjectType, schema: FromObjectSchema, name: str = "") raise TypeError("Expected a list of integers, but a non-integer element was found.") data = cast(Union[bytes, str, int, float, "list[int]"], data) # Create the tag with the data and the name - return schema(data, name=name) + return schema(data, name=name) # type: ignore # The schema is a subclass of NBTag # Sanity check : Verify that all type schemas have been handled if not isinstance(schema, (list, tuple, dict)): @@ -367,7 +367,7 @@ def from_object(data: FromObjectType, schema: FromObjectSchema, name: str = "") if isinstance(schema, dict): # We can unpack the dictionary and create a CompoundNBT tag if not isinstance(data, dict): - raise TypeError(f"Expected a dictionary, but found {type(data).__name__}.") + raise TypeError(f"Expected a dictionary, but found a different type ({type(data).__name__}).") # Iterate over the dictionary payload: list[NBTag] = [] for key, value in data.items(): @@ -404,7 +404,9 @@ def from_object(data: FromObjectType, schema: FromObjectSchema, name: str = "") if isinstance(first_schema, (list, dict)) and not all(isinstance(item, type(first_schema)) for item in schema): raise TypeError(f"Expected a list of lists or dictionaries, but found a different type ({schema=}).") # NBTag case - if isinstance(first_schema, type) and not all(item == first_schema for item in schema): + # Now don't get me wrong, this is actually covered but the coverage tool thinks that it's missing a case with + # an empty list, which is not possible because of the previous checks + if isinstance(first_schema, type) and not all(item == first_schema for item in schema): # pragma: no cover raise TypeError(f"The schema must contain a single type of elements. ({schema=})") for item, sub_schema in zip(data, schema): @@ -474,17 +476,16 @@ def value(self) -> PayloadType: # region NBT tags types +@final +@dataclass class EndNBT(NBTag): """Sentinel tag used to mark the end of a TAG_Compound.""" - __slots__ = () - - def __init__(self): - """Create a new EndNBT tag.""" - super().__init__(0, name="") + payload: None = None + name: str = "" @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = False) -> None: + def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = False) -> None: self._write_header(buf, with_type=with_type, with_name=False) @override @@ -507,151 +508,118 @@ def value(self) -> PayloadType: return NotImplemented -class ByteNBT(NBTag): - """NBT tag representing a single byte value, represented as a signed 8-bit integer.""" +@dataclass +class _NumberNBTag(NBTag): + """Base class for NBT tags representing a number. + + This class is not meant to be used directly, but rather through its subclasses. + """ + + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = NotImplemented # type: ignore + DATA_SIZE: ClassVar[int] = NotImplemented - __slots__ = () payload: int + name: str = "" @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: self._write_header(buf, with_type=with_type, with_name=with_name) - if self.payload < -(1 << 7) or self.payload >= 1 << 7: - raise OverflowError("Byte value out of range.") - - buf.write_value(StructFormat.BYTE, self.payload) + buf.write_value(self.STRUCT_FORMAT, self.payload) @override @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteNBT: + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> _NumberNBTag: name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") - if buf.remaining < 1: - raise IOError("Buffer does not contain enough data to read a byte. (Empty buffer)") + if buf.remaining < cls.DATA_SIZE: + raise IOError(f"Buffer does not contain enough data to read a {tag_type.name}.") - return ByteNBT(buf.read_value(StructFormat.BYTE), name=name) + return cls(buf.read_value(cls.STRUCT_FORMAT), name=name) - def __int__(self) -> int: - """Get the integer value of the ByteNBT tag.""" - return self.payload + @override + def validate(self) -> None: + if not isinstance(self.payload, int): # type: ignore + raise TypeError(f"Expected an int, but found {type(self.payload).__name__}.") + int_min = -(1 << (self.DATA_SIZE * 8 - 1)) + int_max = (1 << (self.DATA_SIZE * 8 - 1)) - 1 + if not int_min <= self.payload <= int_max: + raise OverflowError(f"Value out of range for a {type(self).__name__} tag.") @property @override def value(self) -> int: return self.payload + def __int__(self) -> int: + return self.payload -class ShortNBT(ByteNBT): - """NBT tag representing a short value, represented as a signed 16-bit integer.""" - - __slots__ = () - - @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - self._write_header(buf, with_type=with_type, with_name=with_name) - - if self.payload < -(1 << 15) or self.payload >= 1 << 15: - raise OverflowError("Short value out of range.") - - buf.write_value(StructFormat.SHORT, self.payload) - - @override - @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ShortNBT: - name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if _get_tag_type(cls) != tag_type: - raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") - - if buf.remaining < 2: - raise IOError("Buffer does not contain enough data to read a short.") - - return ShortNBT(buf.read_value(StructFormat.SHORT), name=name) +class ByteNBT(_NumberNBTag): + """NBT tag representing a single byte value, represented as a signed 8-bit integer.""" -class IntNBT(ByteNBT): - """NBT tag representing an integer value, represented as a signed 32-bit integer.""" + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = StructFormat.BYTE + DATA_SIZE: ClassVar[int] = 1 __slots__ = () - @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - self._write_header(buf, with_type=with_type, with_name=with_name) - if self.payload < -(1 << 31) or self.payload >= 1 << 31: - raise OverflowError("Integer value out of range.") - - # No more messing around with the struct, we want 32 bits of data no matter what - buf.write_value(StructFormat.INT, self.payload) +class ShortNBT(_NumberNBTag): + """NBT tag representing a short value, represented as a signed 16-bit integer.""" - @override - @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntNBT: - name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if _get_tag_type(cls) != tag_type: - raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = StructFormat.SHORT + DATA_SIZE: ClassVar[int] = 2 - if buf.remaining < 4: - raise IOError("Buffer does not contain enough data to read an int.") + __slots__ = () - return IntNBT(buf.read_value(StructFormat.INT), name=name) +class IntNBT(_NumberNBTag): + """NBT tag representing an integer value, represented as a signed 32-bit integer.""" -class LongNBT(ByteNBT): - """NBT tag representing a long value, represented as a signed 64-bit integer.""" + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = StructFormat.INT + DATA_SIZE: ClassVar[int] = 4 __slots__ = () - @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - self._write_header(buf, with_type=with_type, with_name=with_name) - if self.payload < -(1 << 63) or self.payload >= 1 << 63: - raise OverflowError("Long value out of range.") +class LongNBT(_NumberNBTag): + """NBT tag representing a long value, represented as a signed 64-bit integer.""" - # No more messing around with the struct, we want 64 bits of data no matter what - buf.write_value(StructFormat.LONGLONG, self.payload) + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = StructFormat.LONGLONG + DATA_SIZE: ClassVar[int] = 8 - @override - @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongNBT: - name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if _get_tag_type(cls) != tag_type: - raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + __slots__ = () - if buf.remaining < 8: - raise IOError("Buffer does not contain enough data to read a long.") - return LongNBT(buf.read_value(StructFormat.LONGLONG), name=name) +@dataclass +class _FloatingNBTag(NBTag): + """Base class for NBT tags representing a floating-point number.""" - -class FloatNBT(NBTag): - """NBT tag representing a floating-point value, represented as a 32-bit IEEE 754-2008 binary32 value.""" + STRUCT_FORMAT: ClassVar[FLOAT_FORMATS_TYPE] = NotImplemented # type: ignore + DATA_SIZE: ClassVar[int] = NotImplemented payload: float - - __slots__ = () + name: str = "" @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: self._write_header(buf, with_type=with_type, with_name=with_name) - buf.write_value(StructFormat.FLOAT, self.payload) + buf.write_value(self.STRUCT_FORMAT, self.payload) @override @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> FloatNBT: + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> _FloatingNBTag: name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") - if buf.remaining < 4: - raise IOError("Buffer does not contain enough data to read a float.") + if buf.remaining < cls.DATA_SIZE: + raise IOError(f"Buffer does not contain enough data to read a {tag_type.name}.") - return FloatNBT(buf.read_value(StructFormat.FLOAT), name=name) + return cls(buf.read_value(cls.STRUCT_FORMAT), name=name) def __float__(self) -> float: - """Get the float value of the FloatNBT tag.""" return self.payload @property @@ -659,41 +627,45 @@ def __float__(self) -> float: def value(self) -> float: return self.payload + @override + def validate(self) -> None: + if isinstance(self.payload, int): + self.payload = float(self.payload) + if not isinstance(self.payload, float): + raise TypeError(f"Expected a float, but found {type(self.payload).__name__}.") + -class DoubleNBT(FloatNBT): - """NBT tag representing a double-precision floating-point value, represented as a 64-bit IEEE 754-2008 binary64.""" +@final +class FloatNBT(_FloatingNBTag): + """NBT tag representing a floating-point value, represented as a 32-bit IEEE 754-2008 binary32 value.""" + + STRUCT_FORMAT: ClassVar[FLOAT_FORMATS_TYPE] = StructFormat.FLOAT + DATA_SIZE: ClassVar[int] = 4 __slots__ = () - @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - self._write_header(buf, with_type=with_type, with_name=with_name) - buf.write_value(StructFormat.DOUBLE, self.payload) - @override - @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> DoubleNBT: - name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if _get_tag_type(cls) != tag_type: - raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") +@final +class DoubleNBT(_FloatingNBTag): + """NBT tag representing a double-precision floating-point value, represented as a 64-bit IEEE 754-2008 binary64.""" - if buf.remaining < 8: - raise IOError("Buffer does not contain enough data to read a double.") + STRUCT_FORMAT: ClassVar[FLOAT_FORMATS_TYPE] = StructFormat.DOUBLE + DATA_SIZE: ClassVar[int] = 8 - return DoubleNBT(buf.read_value(StructFormat.DOUBLE), name=name) + __slots__ = () +@dataclass class ByteArrayNBT(NBTag): """NBT tag representing an array of bytes. The length of the array is stored as a signed 32-bit integer.""" - __slots__ = () - payload: bytes + name: str = "" @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: self._write_header(buf, with_type=with_type, with_name=with_name) - IntNBT(len(self.payload)).write_to(buf, with_type=False, with_name=False) + IntNBT(len(self.payload)).serialize_to(buf, with_type=False, with_name=False) buf.write(self.payload) @override @@ -734,23 +706,30 @@ def __repr__(self) -> str: def value(self) -> bytes: return self.payload + @override + def validate(self) -> None: + if isinstance(self.payload, bytearray): + self.payload = bytes(self.payload) + if not isinstance(self.payload, bytes): + raise TypeError(f"Expected a bytes, but found {type(self.payload).__name__}.") + +@dataclass class StringNBT(NBTag): """NBT tag representing an UTF-8 string value. The length of the string is stored as a signed 16-bit integer.""" - __slots__ = () - payload: str + name: str = "" @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: self._write_header(buf, with_type=with_type, with_name=with_name) if len(self.payload) > 32767: # Check the length of the string (can't generate strings that long in tests) raise ValueError("Maximum character limit for writing strings is 32767 characters.") # pragma: no cover data = bytes(self.payload, "utf-8") - ShortNBT(len(data)).write_to(buf, with_type=False, with_name=False) + ShortNBT(len(data)).serialize_to(buf, with_type=False, with_name=False) buf.write(data) @override @@ -781,22 +760,34 @@ def __str__(self) -> str: def value(self) -> str: return self.payload + @override + def validate(self) -> None: + if not isinstance(self.payload, str): # type: ignore + raise TypeError(f"Expected a str, but found {type(self.payload).__name__}.") + if len(self.payload) > 32767: + raise ValueError("Maximum character limit for writing strings is 32767 characters.") + # Check that the string is valid UTF-8 + try: + self.payload.encode("utf-8") + except UnicodeEncodeError as exc: # pragma: no cover (don't know how to trigger) + raise ValueError("Invalid UTF-8 string.") from exc + +@dataclass class ListNBT(NBTag): """NBT tag representing a list of tags. All tags in the list must be of the same type.""" - __slots__ = () - payload: list[NBTag] + name: str = "" @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: self._write_header(buf, with_type=with_type, with_name=with_name) if not self.payload: # Set the tag type to TAG_End if the list is empty - EndNBT().write_to(buf, with_name=False) - IntNBT(0).write_to(buf, with_name=False, with_type=False) + EndNBT().serialize_to(buf, with_name=False) + IntNBT(0).serialize_to(buf, with_name=False, with_type=False) return if not all(isinstance(tag, NBTag) for tag in self.payload): # type: ignore # We want to check anyway @@ -806,15 +797,15 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) ) tag_type = _get_tag_type(self.payload[0]) - ByteNBT(tag_type).write_to(buf, with_name=False, with_type=False) - IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False) + ByteNBT(tag_type).serialize_to(buf, with_name=False, with_type=False) + IntNBT(len(self.payload)).serialize_to(buf, with_name=False, with_type=False) for tag in self.payload: if tag_type != _get_tag_type(tag): raise ValueError(f"All tags in a list must be of the same type, got tag {tag!r}") if tag.name != "": raise ValueError(f"All tags in a list must be unnamed, got tag {tag!r}") - tag.write_to(buf, with_type=False, with_name=False) + tag.serialize_to(buf, with_type=False, with_name=False) @override @classmethod @@ -896,19 +887,33 @@ def to_object( def value(self) -> list[PayloadType]: return [tag.value for tag in self.payload] + @override + def validate(self) -> None: + if not isinstance(self.payload, list): # type: ignore + raise TypeError(f"Expected a list, but found {type(self.payload).__name__}.") + if not all(isinstance(tag, NBTag) for tag in self.payload): # type: ignore # We want to check anyway + raise TypeError("All items in a list must be NBTags.") + if not self.payload: + return + first_tag_type = type(self.payload[0]) + if not all(type(tag) is first_tag_type for tag in self.payload): + raise TypeError("All tags in a list must be of the same type.") + if not all(tag.name == "" for tag in self.payload): + raise ValueError("All tags in a list must be unnamed.") + +@dataclass class CompoundNBT(NBTag): """NBT tag representing a compound of named tags.""" - __slots__ = () - payload: list[NBTag] + name: str = "" @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: self._write_header(buf, with_type=with_type, with_name=with_name) if not self.payload: - EndNBT().write_to(buf, with_name=False, with_type=True) + EndNBT().serialize_to(buf, with_name=False, with_type=True) return if not all(isinstance(tag, NBTag) for tag in self.payload): # type: ignore # We want to check anyway raise ValueError( @@ -923,8 +928,8 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) raise ValueError("All tags in a compound must have unique names.") for tag in self.payload: - tag.write_to(buf) - EndNBT().write_to(buf, with_name=False, with_type=True) + tag.serialize_to(buf) + EndNBT().serialize_to(buf, with_name=False, with_type=True) @override @classmethod @@ -1008,95 +1013,86 @@ def __eq__(self, other: object) -> bool: def value(self) -> dict[str, PayloadType]: return {tag.name: tag.value for tag in self.payload} + @override + def validate(self) -> None: + if not isinstance(self.payload, list): # type: ignore + raise TypeError(f"Expected a list, but found {type(self.payload).__name__}.") + if not all(isinstance(tag, NBTag) for tag in self.payload): # type: ignore + raise TypeError("All items in a compound must be NBTags.") + if not all(tag.name for tag in self.payload): + raise ValueError("All tags in a compound must be named.") + if len(self.payload) != len({tag.name for tag in self.payload}): + raise ValueError("All tags in a compound must have unique names.") + -class IntArrayNBT(NBTag): - """NBT tag representing an array of integers. The length of the array is stored as a signed 32-bit integer.""" +@dataclass +class _NumberArrayNBTag(NBTag): + """Base class for NBT tags representing an array of numbers.""" - __slots__ = () + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = NotImplemented # type: ignore + DATA_SIZE: ClassVar[int] = NotImplemented payload: list[int] + name: str = "" @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: self._write_header(buf, with_type=with_type, with_name=with_name) - - if any(not isinstance(item, int) for item in self.payload): # type: ignore # We want to check anyway - raise ValueError("All items in an integer array must be integers.") - - if any(item < -(1 << 31) or item >= 1 << 31 for item in self.payload): - raise OverflowError("Integer array contains values out of range.") - - IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False) + IntNBT(len(self.payload)).serialize_to(buf, with_name=False, with_type=False) for i in self.payload: - IntNBT(i).write_to(buf, with_name=False, with_type=False) + buf.write_value(self.STRUCT_FORMAT, i) @override @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntArrayNBT: + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> _NumberArrayNBTag: name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != NBTagType.INT_ARRAY: - raise TypeError(f"Expected an INT_ARRAY tag, but found a different tag ({tag_type}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") length = IntNBT.read_from(buf, with_type=False, with_name=False).value - try: - payload = [IntNBT.read_from(buf, with_type is NBTagType.INT, with_name=False).value for _ in range(length)] - except IOError as exc: - raise IOError( - "Buffer does not contain enough data to read the entire integer array. (Incomplete data)" - ) from exc - return IntArrayNBT(payload, name=name) - @override - def __repr__(self) -> str: - if self.name: - return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" - if len(self.payload) < 8: - return f"{type(self).__name__}(length={len(self.payload)}, {self.payload!r})" - return f"{type(self).__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" + if buf.remaining < length * cls.DATA_SIZE: + raise IOError(f"Buffer does not contain enough data to read the entire {tag_type.name}.") - def __iter__(self) -> Iterator[int]: - """Iterate over the integers in the array.""" - yield from self.payload + return cls([buf.read_value(cls.STRUCT_FORMAT) for _ in range(length)], name=name) + + @override + def validate(self) -> None: + if not isinstance(self.payload, list): # type: ignore + raise TypeError(f"Expected a list, but found {type(self.payload).__name__}.") + if not all(isinstance(item, int) for item in self.payload): # type: ignore + raise TypeError("All items in an integer array must be integers.") + if any( + item < -(1 << (self.DATA_SIZE * 8 - 1)) or item >= 1 << (self.DATA_SIZE * 8 - 1) for item in self.payload + ): + raise OverflowError(f"Integer array contains values out of range. ({self.payload})") @property @override def value(self) -> list[int]: return self.payload + def __iter__(self) -> Iterator[int]: + yield from self.payload -class LongArrayNBT(IntArrayNBT): - """NBT tag representing an array of longs. The length of the array is stored as a signed 32-bit integer.""" - __slots__ = () +@final +class IntArrayNBT(_NumberArrayNBTag): + """NBT tag representing an array of integers. The length of the array is stored as a signed 32-bit integer.""" - @override - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - self._write_header(buf, with_type=with_type, with_name=with_name) + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = StructFormat.INT + DATA_SIZE: ClassVar[int] = 4 - if any(not isinstance(item, int) for item in self.payload): # type: ignore # We want to check anyway - raise ValueError(f"All items in a long array must be integers. ({self.payload})") + __slots__ = () - if any(item < -(1 << 63) or item >= 1 << 63 for item in self.payload): - raise OverflowError(f"Long array contains values out of range. ({self.payload})") - IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False) - for i in self.payload: - LongNBT(i).write_to(buf, with_name=False, with_type=False) +@final +class LongArrayNBT(_NumberArrayNBTag): + """NBT tag representing an array of longs. The length of the array is stored as a signed 32-bit integer.""" - @override - @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongArrayNBT: - name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != NBTagType.LONG_ARRAY: - raise TypeError(f"Expected a LONG_ARRAY tag, but found a different tag ({tag_type}).") - length = IntNBT.read_from(buf, with_type=False, with_name=False).payload + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = StructFormat.LONGLONG + DATA_SIZE: ClassVar[int] = 8 - try: - payload = [LongNBT.read_from(buf, with_type=False, with_name=False).payload for _ in range(length)] - except IOError as exc: - raise IOError( - "Buffer does not contain enough data to read the entire long array. (Incomplete data)" - ) from exc - return LongArrayNBT(payload, name=name) + __slots__ = () # endregion diff --git a/mcproto/types/uuid.py b/mcproto/types/uuid.py index 97fddde8..500710b5 100644 --- a/mcproto/types/uuid.py +++ b/mcproto/types/uuid.py @@ -12,7 +12,7 @@ @final -class UUID(MCType, uuid.UUID): +class UUID(uuid.UUID, MCType): """Minecraft UUID type. In order to support potential future changes in protocol version, and implement McType, @@ -22,12 +22,11 @@ class UUID(MCType, uuid.UUID): __slots__ = () @override - def serialize(self) -> Buffer: - buf = Buffer() + def serialize_to(self, buf: Buffer) -> None: buf.write(self.bytes) - return buf @override @classmethod def deserialize(cls, buf: Buffer, /) -> Self: - return cls(bytes=bytes(buf.read(16))) + data = bytes(buf.read(16)) + return cls(bytes=data) diff --git a/mcproto/utils/abc.py b/mcproto/utils/abc.py index c2873494..5c7dcc64 100644 --- a/mcproto/utils/abc.py +++ b/mcproto/utils/abc.py @@ -1,20 +1,31 @@ from __future__ import annotations -from abc import ABC, abstractmethod +import sys +from abc import ABC, ABCMeta, abstractmethod from collections.abc import Sequence -from typing import ClassVar +from dataclasses import dataclass as _dataclass +from functools import partial +from typing import Any, ClassVar, TYPE_CHECKING from typing_extensions import Self from mcproto.buffer import Buffer -__all__ = ["RequiredParamsABCMixin", "Serializable"] +__all__ = ["RequiredParamsABCMixin", "Serializable", "dataclass"] + +if TYPE_CHECKING: + dataclass = _dataclass # The type checker needs + +if sys.version_info >= (3, 10): + dataclass = partial(_dataclass, slots=True) +else: + dataclass = _dataclass class RequiredParamsABCMixin: """Mixin class to ABCs that require certain attributes to be set in order to allow initialization. - This class performs a similar check to what :class:`~abc.ABC` already des with abstractmethods, + This class performs a similar check to what :class:`~abc.ABC` already does with abstractmethods, but for class variables. The required class variable names are set with :attr:`._REQUIRED_CLASS_VARS` class variable, which itself is automatically required. @@ -37,7 +48,7 @@ class vars which should be defined on given class directly. That means inheritan _REQUIRRED_CLASS_VARS: ClassVar[Sequence[str]] _REQUIRED_CLASS_VARS_NO_MRO: ClassVar[Sequence[str]] - def __new__(cls: type[Self], *a, **kw) -> Self: + def __new__(cls: type[Self], *a: Any, **kw: Any) -> Self: """Enforce required parameters being set for each instance of the concrete classes.""" _err_msg = f"Can't instantiate abstract {cls.__name__} class without defining " + "{!r} classvar" @@ -63,18 +74,60 @@ def __new__(cls: type[Self], *a, **kw) -> Self: return super().__new__(cls) -class Serializable(ABC): - """Base class for any type that should be (de)serializable into/from :class:`~mcproto.Buffer` data.""" +class _MetaDataclass(ABCMeta): + def __new__( + cls: type[_MetaDataclass], + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + **kwargs: Any, + ) -> Any: # Create the class using the super() method to ensure it is correctly formed as an ABC + new_class = super().__new__(cls, name, bases, namespace, **kwargs) + + # Check if the dataclass is already defined, if not, create it + if not hasattr(new_class, "__dataclass_fields__"): + new_class = dataclass(new_class) + + return new_class + + +class Serializable(ABC): # , metaclass=_MetaDataclass): + """Base class for any type that should be (de)serializable into/from :class:`~mcproto.Buffer` data. + + Any class that inherits from this class and adds parameters should use the :func:`~mcproto.utils.abc.dataclass` + decorator. + """ __slots__ = () - @abstractmethod + def __post_init__(self) -> None: + """Run the validation method after the object is initialized.""" + self.validate() + def serialize(self) -> Buffer: """Represent the object as a :class:`~mcproto.Buffer` (transmittable sequence of bytes).""" + self.validate() + buf = Buffer() + self.serialize_to(buf) + return buf + + @abstractmethod + def serialize_to(self, buf: Buffer, /) -> None: + """Write the object to a :class:`~mcproto.Buffer`.""" raise NotImplementedError + def validate(self) -> None: + """Validate the object's attributes, raising an exception if they are invalid. + + This will be called at the end of the object's initialization, and before serialization. + Use cast() in serialize_to() if your validation asserts that a value is of a certain type. + + By default, this method does nothing. Override it in your subclass to add validation logic. + """ + return + @classmethod @abstractmethod def deserialize(cls, buf: Buffer, /) -> Self: - """Construct the object from a :class:`~mcproto.Buffer` (transmitable sequence of bytes).""" + """Construct the object from a :class:`~mcproto.Buffer` (transmittable sequence of bytes).""" raise NotImplementedError diff --git a/tests/helpers.py b/tests/helpers.py index 44de2032..148fe973 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,14 +4,20 @@ import inspect import unittest.mock from collections.abc import Callable, Coroutine -from typing import Any, Generic, TypeVar +from typing import Any, Dict, Generic, Tuple, TypeVar, cast +import pytest from typing_extensions import ParamSpec, override +from mcproto.buffer import Buffer +from mcproto.utils.abc import Serializable + T = TypeVar("T") P = ParamSpec("P") T_Mock = TypeVar("T_Mock", bound=unittest.mock.Mock) +__all__ = ["synchronize", "SynchronizedMixin", "UnpropagatingMockMixin", "CustomMockMixin", "gen_serializable_test"] + def synchronize(f: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]: """Take an asynchronous function, and return a synchronous alternative. @@ -160,3 +166,174 @@ def __init__(self, **kwargs): if "spec_set" in kwargs: self.spec_set = kwargs.pop("spec_set") super().__init__(spec_set=self.spec_set, **kwargs) # type: ignore # Mixin class, this __init__ is valid + + +def gen_serializable_test( + context: dict[str, Any], + cls: type[Serializable], + fields: list[tuple[str, type]], + test_data: list[ + tuple[tuple[Any, ...], bytes] | tuple[tuple[Any, ...], type[Exception]] | tuple[type[Exception], bytes] + ], +): + """Generate tests for a serializable class. + + This function generates tests for the serialization, deserialization, validation, and deserialization error + handling + + :param context: The context to add the test functions to. This is usually `globals()`. + :param cls: The serializable class to test. + :param fields: A list of tuples containing the field names and types of the serializable class. + :param test_data: A list of test data. Each element is a tuple containing either: + - A tuple of parameters to pass to the serializable class constructor and the expected bytes after + serialization + - A tuple of parameters to pass to the serializable class constructor and the expected exception during + validation + - An exception to expect during deserialization and the bytes to deserialize + + Example usage: + ```python + @final + @dataclass + class ToyClass(Serializable): + a: int + b: str + + def serialize_to(self, buf: Buffer): + buf.write_varint(self.a) + buf.write_utf(self.b) + + @classmethod + def deserialize(cls, buf: Buffer) -> "ToyClass": + a = buf.read_varint() + b = buf.read_utf() + if len(b) > 10: + raise ValueError("b must be less than 10 characters") + return cls(a, b) + + def validate(self) -> None: + if self.a == 0: + raise ZeroDivisionError("a must be non-zero") + + gen_serializable_test( + context=globals(), + cls=ToyClass, + fields=[("a", int), ("b", str)], + test_data=[ + ((1, "hello"), b"\x01\x05hello"), + ((2, "world"), b"\x02\x05world"), + ((0, "hello"), ZeroDivisionError), + (IOError, b"\x01"), # Not enough data to deserialize + (ValueError, b"\x01\x0bhello world"), # 0b = 11 is too long + ], + ) + ``` + This will add 1 class test with 4 test functions containing the tests for serialization, deserialization, + validation, and deserialization error handling + + + """ + # Separate the test data into parameters for each test function + # This holds the parameters for the serialization and deserialization tests + parameters: list[tuple[dict[str, Any], bytes]] = [] + + # This holds the parameters for the validation tests + validation_fail: list[tuple[dict[str, Any], type[Exception]]] = [] + + # This holds the parameters for the deserialization error tests + deserialization_fail: list[tuple[bytes, type[Exception]]] = [] + + for data_or_exc, expected_bytes_or_exc in test_data: + if isinstance(data_or_exc, tuple) and isinstance(expected_bytes_or_exc, bytes): + kwargs = dict(zip([f[0] for f in fields], data_or_exc)) + parameters.append((kwargs, expected_bytes_or_exc)) + elif isinstance(data_or_exc, type) and isinstance(expected_bytes_or_exc, bytes): + deserialization_fail.append((expected_bytes_or_exc, data_or_exc)) + else: + data = cast(Tuple[Any, ...], data_or_exc) + exception = cast(type[Exception], expected_bytes_or_exc) + kwargs = dict(zip([f[0] for f in fields], data)) + validation_fail.append((kwargs, exception)) + + def generate_name(param: dict[str, Any] | bytes, i: int) -> str: + """Generate a name for the test case.""" + length = 30 + result = f"{i:02d}] : " # the first [ is added by pytest + if isinstance(param, bytes): + try: + result += str(param[:length], "utf-8") + "..." if len(param) > (length + 3) else param.decode("utf-8") + except UnicodeDecodeError: + result += repr(param[:length]) + "..." if len(param) > (length + 3) else repr(param) + else: + param = cast(Dict[str, Any], param) + begin = ", ".join(f"{k}={v}" for k, v in param.items()) + result += begin[:length] + "..." if len(begin) > (length + 3) else begin + result = result.replace("\n", "\\n").replace("\r", "\\r") + result += f" [{cls.__name__}" # the other [ is added by pytest + return result + + class TestClass: + """Test class for the generated tests.""" + + @pytest.mark.parametrize( + ("kwargs", "expected_bytes"), + parameters, + ids=tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(parameters)), + ) + def test_serialization(self, kwargs: dict[str, Any], expected_bytes: bytes): + """Test serialization of the object.""" + obj = cls(**kwargs) + serialized_bytes = obj.serialize() + assert serialized_bytes == expected_bytes, f"{serialized_bytes} != {expected_bytes}" + + @pytest.mark.parametrize( + ("kwargs", "expected_bytes"), + parameters, + ids=tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(parameters)), + ) + def test_deserialization(self, kwargs: dict[str, Any], expected_bytes: bytes): + """Test deserialization of the object.""" + buf = Buffer(expected_bytes) + obj = cls.deserialize(buf) + assert cls(**kwargs) == obj, f"{cls.__name__}({kwargs}) != {obj}" + assert buf.remaining == 0, f"Buffer has {buf.remaining} bytes remaining" + + @pytest.mark.parametrize( + ("kwargs", "exc"), + validation_fail, + ids=tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(validation_fail)), + ) + def test_validation(self, kwargs: dict[str, Any], exc: type[Exception]): + """Test validation of the object.""" + with pytest.raises(exc): + cls(**kwargs) + + @pytest.mark.parametrize( + ("content", "exc"), + deserialization_fail, + ids=tuple(generate_name(content, i) for i, (content, _) in enumerate(deserialization_fail)), + ) + def test_deserialization_error(self, content: bytes, exc: type[Exception]): + """Test deserialization error handling.""" + buf = Buffer(content) + with pytest.raises(exc): + cls.deserialize(buf) + + if len(parameters) == 0: + # If there are no serialization tests, remove them + del TestClass.test_serialization + del TestClass.test_deserialization + + if len(validation_fail) == 0: + # If there are no validation tests, remove them + del TestClass.test_validation + + if len(deserialization_fail) == 0: + # If there are no deserialization error tests, remove them + del TestClass.test_deserialization_error + + # Set the names of the class + TestClass.__name__ = f"TestGen{cls.__name__}" + + # Add the test functions to the global context + context[TestClass.__name__] = TestClass diff --git a/tests/mcproto/packets/handshaking/test_handshake.py b/tests/mcproto/packets/handshaking/test_handshake.py index 42eb395e..0e34952f 100644 --- a/tests/mcproto/packets/handshaking/test_handshake.py +++ b/tests/mcproto/packets/handshaking/test_handshake.py @@ -1,101 +1,38 @@ from __future__ import annotations -from typing import Any - -import pytest - -from mcproto.buffer import Buffer from mcproto.packets.handshaking.handshake import Handshake, NextState - - -@pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"protocol_version": 757, "server_address": "mc.aircs.racing", "server_port": 25565, "next_state": 2}, - bytes.fromhex("f5050f6d632e61697263732e726163696e6763dd02"), - ), - ( - {"protocol_version": 757, "server_address": "mc.aircs.racing", "server_port": 25565, "next_state": 1}, - bytes.fromhex("f5050f6d632e61697263732e726163696e6763dd01"), - ), - ( - { - "protocol_version": 757, - "server_address": "hypixel.net", - "server_port": 25565, - "next_state": NextState.LOGIN, - }, - bytes.fromhex("f5050b6879706978656c2e6e657463dd02"), - ), - ( - { - "protocol_version": 757, - "server_address": "hypixel.net", - "server_port": 25565, - "next_state": NextState.STATUS, - }, - bytes.fromhex("f5050b6879706978656c2e6e657463dd01"), - ), +from tests.helpers import gen_serializable_test + +gen_serializable_test( + context=globals(), + cls=Handshake, + fields=[ + ("protocol_version", int), + ("server_address", str), + ("server_port", int), + ("next_state", NextState), ], -) -def test_serialize(kwargs: dict[str, Any], expected_bytes: list[int]): - """Test serialization of Handshake packet.""" - handshake = Handshake(**kwargs) - assert handshake.serialize().flush() == bytearray(expected_bytes) - - -@pytest.mark.parametrize( - ("read_bytes", "expected_out"), - [ + test_data=[ ( + (757, "mc.aircs.racing", 25565, NextState.LOGIN), bytes.fromhex("f5050f6d632e61697263732e726163696e6763dd02"), - { - "protocol_version": 757, - "server_address": "mc.aircs.racing", - "server_port": 25565, - "next_state": NextState.LOGIN, - }, ), ( + (757, "mc.aircs.racing", 25565, NextState.STATUS), bytes.fromhex("f5050f6d632e61697263732e726163696e6763dd01"), - { - "protocol_version": 757, - "server_address": "mc.aircs.racing", - "server_port": 25565, - "next_state": NextState.STATUS, - }, ), ( + (757, "hypixel.net", 25565, NextState.LOGIN), bytes.fromhex("f5050b6879706978656c2e6e657463dd02"), - { - "protocol_version": 757, - "server_address": "hypixel.net", - "server_port": 25565, - "next_state": NextState.LOGIN, - }, ), ( + (757, "hypixel.net", 25565, NextState.STATUS), bytes.fromhex("f5050b6879706978656c2e6e657463dd01"), - { - "protocol_version": 757, - "server_address": "hypixel.net", - "server_port": 25565, - "next_state": NextState.STATUS, - }, ), + # Invalid next state + ((757, "localhost", 25565, 3), ValueError), + ((757, "localhost", 25565, 4), ValueError), + ((757, "localhost", 25565, 5), ValueError), + ((757, "localhost", 25565, 6), ValueError), ], ) -def test_deserialize(read_bytes: list[int], expected_out: dict[str, Any]): - """Test deserialization of Handshake packet.""" - handshake = Handshake.deserialize(Buffer(read_bytes)) - - for i, v in expected_out.items(): - assert getattr(handshake, i) == v - - -@pytest.mark.parametrize(("state"), [3, 4, 5, 6]) -def test_invalid_state(state): - """Test initialization of Handshake packet with invalid next state raises :exc:`ValueError`.""" - with pytest.raises(ValueError): - Handshake(protocol_version=757, server_address="localhost", server_port=25565, next_state=state) diff --git a/tests/mcproto/packets/login/test_login.py b/tests/mcproto/packets/login/test_login.py index 23c21cfc..71067022 100644 --- a/tests/mcproto/packets/login/test_login.py +++ b/tests/mcproto/packets/login/test_login.py @@ -1,10 +1,5 @@ from __future__ import annotations -from typing import Any - -import pytest - -from mcproto.buffer import Buffer from mcproto.packets.login.login import ( LoginDisconnect, LoginEncryptionRequest, @@ -15,288 +10,132 @@ LoginStart, LoginSuccess, ) +from mcproto.packets.packet import InvalidPacketContentError from mcproto.types.chat import ChatMessage from mcproto.types.uuid import UUID +from tests.helpers import gen_serializable_test from tests.mcproto.test_encryption import RSA_PUBLIC_KEY +# LoginStart +gen_serializable_test( + context=globals(), + cls=LoginStart, + fields=[("username", str), ("uuid", UUID)], + test_data=[ + ( + ("ItsDrike", UUID("f70b4a42c9a04ffb92a31390c128a1b2")), + bytes.fromhex("084974734472696b65f70b4a42c9a04ffb92a31390c128a1b2"), + ), + ( + ("foobar1", UUID("7a82476416fc4e8b8686a99c775db7d3")), + bytes.fromhex("07666f6f626172317a82476416fc4e8b8686a99c775db7d3"), + ), + ], +) -class TestLoginStart: - """Collection of tests for the LoginStart packet.""" - - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"username": "ItsDrike", "uuid": UUID("f70b4a42c9a04ffb92a31390c128a1b2")}, - bytes.fromhex("084974734472696b65f70b4a42c9a04ffb92a31390c128a1b2"), - ), - ( - {"username": "foobar1", "uuid": UUID("7a82476416fc4e8b8686a99c775db7d3")}, - bytes.fromhex("07666f6f626172317a82476416fc4e8b8686a99c775db7d3"), - ), - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginStart packet.""" - packet = LoginStart(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) - - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("084974734472696b65f70b4a42c9a04ffb92a31390c128a1b2"), - {"username": "ItsDrike", "uuid": UUID("f70b4a42c9a04ffb92a31390c128a1b2")}, - ), - ( - bytes.fromhex("07666f6f626172317a82476416fc4e8b8686a99c775db7d3"), - {"username": "foobar1", "uuid": UUID("7a82476416fc4e8b8686a99c775db7d3")}, - ), - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginStart packet.""" - packet = LoginStart.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val - - -class TestLoginEncryptionRequest: - """Collection of tests for the LoginEncryptionRequest packet.""" - - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"public_key": RSA_PUBLIC_KEY, "verify_token": bytes.fromhex("9bd416ef"), "server_id": "a" * 20}, - bytes.fromhex( - "146161616161616161616161616161616161616161a20130819f300d06092a864886f70d010101050003818d003081890" - "2818100cb515109911ea3e4740d8a17a7ccd9cf226c83c7615e4a5505cd124571ee210a4ba26c7c42e15f51fcb7fa90dc" - "e6f83ebe0e163817c7d9fb1af7d981e90da2cc06ea59d01ff9fbb76b1803a0fe5af4a2c75145d89eb03e6a4aae21d2e7d" - "4c3938a298da575e12e0ae178d61a69bc0ea0b381790f182d9dba715bfb503c99d92b0203010001049bd416ef" - ), - ), - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginEncryptionRequest packet.""" - packet = LoginEncryptionRequest(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) - - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex( - "146161616161616161616161616161616161616161a20130819f300d06092a864886f70d010101050003818d003081890" - "2818100cb515109911ea3e4740d8a17a7ccd9cf226c83c7615e4a5505cd124571ee210a4ba26c7c42e15f51fcb7fa90dc" - "e6f83ebe0e163817c7d9fb1af7d981e90da2cc06ea59d01ff9fbb76b1803a0fe5af4a2c75145d89eb03e6a4aae21d2e7d" - "4c3938a298da575e12e0ae178d61a69bc0ea0b381790f182d9dba715bfb503c99d92b0203010001049bd416ef" - ), - {"public_key": RSA_PUBLIC_KEY, "verify_token": bytes.fromhex("9bd416ef"), "server_id": "a" * 20}, +# LoginEncryptionRequest +gen_serializable_test( + context=globals(), + cls=LoginEncryptionRequest, + fields=[("server_id", str), ("public_key", bytes), ("verify_token", bytes)], + test_data=[ + ( + ("a" * 20, RSA_PUBLIC_KEY, bytes.fromhex("9bd416ef")), + bytes.fromhex( + "146161616161616161616161616161616161616161a20130819f300d06092a864886f70d010101050003818d003081890" + "2818100cb515109911ea3e4740d8a17a7ccd9cf226c83c7615e4a5505cd124571ee210a4ba26c7c42e15f51fcb7fa90dc" + "e6f83ebe0e163817c7d9fb1af7d981e90da2cc06ea59d01ff9fbb76b1803a0fe5af4a2c75145d89eb03e6a4aae21d2e7d" + "4c3938a298da575e12e0ae178d61a69bc0ea0b381790f182d9dba715bfb503c99d92b0203010001049bd416ef" ), - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginEncryptionRequest packet.""" - packet = LoginEncryptionRequest.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val - - -class TestLoginEncryptionResponse: - """Collection of tests for the LoginEncryptionResponse packet.""" - - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"shared_secret": b"I'm shared", "verify_token": b"Token"}, - bytes.fromhex("0a49276d2073686172656405546f6b656e"), - ) - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginEncryptionRespones packet.""" - packet = LoginEncryptionResponse(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) - - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("0a49276d2073686172656405546f6b656e"), - {"shared_secret": b"I'm shared", "verify_token": b"Token"}, - ) - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginEncryptionRespones packet.""" - packet = LoginEncryptionResponse.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val - - -class TestLoginSuccess: - """Collection of tests for the LoginSuccess packet.""" - - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"uuid": UUID("f70b4a42c9a04ffb92a31390c128a1b2"), "username": "Mario"}, - bytes.fromhex("f70b4a42c9a04ffb92a31390c128a1b2054d6172696f"), - ) - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginSuccess packet.""" - packet = LoginSuccess(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) - - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("f70b4a42c9a04ffb92a31390c128a1b2054d6172696f"), - {"uuid": UUID("f70b4a42c9a04ffb92a31390c128a1b2"), "username": "Mario"}, - ) - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginSuccess packet.""" - packet = LoginSuccess.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val - - -class TestLoginDisconnect: - """Collection of tests for the LoginDisconnect packet.""" - - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"reason": ChatMessage("You are banned.")}, - bytes.fromhex("1122596f75206172652062616e6e65642e22"), - ) - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginDisconnect packet.""" - packet = LoginDisconnect(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) - - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("1122596f75206172652062616e6e65642e22"), - {"reason": ChatMessage("You are banned.")}, - ) - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginDisconnect packet.""" - packet = LoginDisconnect.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val + ), + (InvalidPacketContentError, bytes.fromhex("14")), + ], +) -class TestLoginPluginRequest: - """Collection of tests for the LoginPluginRequest packet.""" +def test_login_encryption_request_noid(): + """Test LoginEncryptionRequest without server_id.""" + packet = LoginEncryptionRequest(server_id=None, public_key=RSA_PUBLIC_KEY, verify_token=bytes.fromhex("9bd416ef")) + assert packet.server_id == " " * 20 # None is converted to an empty server id - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"message_id": 0, "channel": "xyz", "data": b"Hello"}, - bytes.fromhex("000378797a48656c6c6f"), - ) - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginPluginRequest packet.""" - packet = LoginPluginRequest(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("000378797a48656c6c6f"), - {"message_id": 0, "channel": "xyz", "data": b"Hello"}, - ) - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test serialization of LoginPluginRequest packet.""" - packet = LoginPluginRequest.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val +# TestLoginEncryptionResponse +gen_serializable_test( + context=globals(), + cls=LoginEncryptionResponse, + fields=[("shared_secret", bytes), ("verify_token", bytes)], + test_data=[ + ( + (b"I'm shared", b"Token"), + bytes.fromhex("0a49276d2073686172656405546f6b656e"), + ), + ], +) -class TestLoginPluginResponse: - """Collection of tests for the LoginPluginResponse packet.""" +# LoginSuccess +gen_serializable_test( + context=globals(), + cls=LoginSuccess, + fields=[("uuid", UUID), ("username", str)], + test_data=[ + ( + (UUID("f70b4a42c9a04ffb92a31390c128a1b2"), "Mario"), + bytes.fromhex("f70b4a42c9a04ffb92a31390c128a1b2054d6172696f"), + ), + ], +) - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"message_id": 0, "data": b"Hello"}, - bytes.fromhex("000148656c6c6f"), - ) - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginPluginResponse packet.""" - packet = LoginPluginResponse(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) +# LoginDisconnect +gen_serializable_test( + context=globals(), + cls=LoginDisconnect, + fields=[("reason", ChatMessage)], + test_data=[ + ( + (ChatMessage("You are banned."),), + bytes.fromhex("1122596f75206172652062616e6e65642e22"), + ), + ], +) - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("000148656c6c6f"), - {"message_id": 0, "data": b"Hello"}, - ) - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginPluginResponse packet.""" - packet = LoginPluginResponse.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val +# LoginPluginRequest +gen_serializable_test( + context=globals(), + cls=LoginPluginRequest, + fields=[("message_id", int), ("channel", str), ("data", bytes)], + test_data=[ + ( + (0, "xyz", b"Hello"), + bytes.fromhex("000378797a48656c6c6f"), + ), + ], +) -class TestLoginSetCompression: - """Collection of tests for the LoginSetCompression packet.""" - @pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ( - {"threshold": 2}, - bytes.fromhex("02"), - ) - ], - ) - def test_serialize(self, kwargs: dict[str, Any], expected_bytes: bytes): - """Test serialization of LoginSetCompression packet.""" - packet = LoginSetCompression(**kwargs) - assert packet.serialize().flush() == bytearray(expected_bytes) +# LoginPluginResponse +gen_serializable_test( + context=globals(), + cls=LoginPluginResponse, + fields=[("message_id", int), ("data", bytes)], + test_data=[ + ( + (0, b"Hello"), + bytes.fromhex("000148656c6c6f"), + ), + ], +) - @pytest.mark.parametrize( - ("input_bytes", "expected_args"), - [ - ( - bytes.fromhex("02"), - {"threshold": 2}, - ) - ], - ) - def test_deserialize(self, input_bytes: bytes, expected_args: dict[str, Any]): - """Test deserialization of LoginSetCompression packet.""" - packet = LoginSetCompression.deserialize(Buffer(input_bytes)) - for arg_name, val in expected_args.items(): - assert getattr(packet, arg_name) == val +# LoginSetCompression +gen_serializable_test( + context=globals(), + cls=LoginSetCompression, + fields=[("threshold", int)], + test_data=[ + ( + (2,), + bytes.fromhex("02"), + ), + ], +) diff --git a/tests/mcproto/packets/status/test_ping.py b/tests/mcproto/packets/status/test_ping.py index bef9248b..245a03aa 100644 --- a/tests/mcproto/packets/status/test_ping.py +++ b/tests/mcproto/packets/status/test_ping.py @@ -1,36 +1,20 @@ from __future__ import annotations -from typing import Any - -import pytest - -from mcproto.buffer import Buffer from mcproto.packets.status.ping import PingPong - - -@pytest.mark.parametrize( - ("kwargs", "expected_bytes"), - [ - ({"payload": 2806088}, bytes.fromhex("00000000002ad148")), - ({"payload": 123456}, bytes.fromhex("000000000001e240")), +from tests.helpers import gen_serializable_test + +gen_serializable_test( + context=globals(), + cls=PingPong, + fields=[("payload", int)], + test_data=[ + ( + (2806088,), + bytes.fromhex("00000000002ad148"), + ), + ( + (123456,), + bytes.fromhex("000000000001e240"), + ), ], ) -def test_serialize(kwargs: dict[str, Any], expected_bytes: list[int]): - """Test serialization of PingPong packet.""" - ping = PingPong(**kwargs) - assert ping.serialize().flush() == bytearray(expected_bytes) - - -@pytest.mark.parametrize( - ("read_bytes", "expected_out"), - [ - (bytes.fromhex("00000000002ad148"), {"payload": 2806088}), - (bytes.fromhex("000000000001e240"), {"payload": 123456}), - ], -) -def test_deserialize(read_bytes: list[int], expected_out: dict[str, Any]): - """Test deserialization of PingPong packet.""" - ping = PingPong.deserialize(Buffer(read_bytes)) - - for i, v in expected_out.items(): - assert getattr(ping, i) == v diff --git a/tests/mcproto/packets/status/test_status.py b/tests/mcproto/packets/status/test_status.py index 7f9244d5..bd40a31b 100644 --- a/tests/mcproto/packets/status/test_status.py +++ b/tests/mcproto/packets/status/test_status.py @@ -1,67 +1,36 @@ from __future__ import annotations -import json -from typing import Any +from typing import Any, Dict -import pytest - -from mcproto.buffer import Buffer from mcproto.packets.status.status import StatusResponse +from tests.helpers import gen_serializable_test - -@pytest.mark.parametrize( - ("data", "expected_bytes"), - [ +gen_serializable_test( + context=globals(), + cls=StatusResponse, + fields=[("data", Dict[str, Any])], + test_data=[ ( - { - "description": {"text": "A Minecraft Server"}, - "players": {"max": 20, "online": 0}, - "version": {"name": "1.18.1", "protocol": 757}, - }, - bytes.fromhex( - "797b226465736372697074696f6e223a7b2274657874223a2241204d696e6" - "5637261667420536572766572227d2c22706c6179657273223a7b226d6178" - "223a32302c226f6e6c696e65223a20307d2c2276657273696f6e223a7b226" - "e616d65223a22312e31382e31222c2270726f746f636f6c223a3735377d7d" + ( + { + "description": {"text": "A Minecraft Server"}, + "players": {"max": 20, "online": 0}, + "version": {"name": "1.18.1", "protocol": 757}, + }, ), - ), - ], -) -def test_serialize(data: dict[str, Any], expected_bytes: bytes): - """Test serialization of StatusResponse packet.""" - expected_buffer = Buffer(expected_bytes) - # Clear the length before the actual JSON data. JSON strings are encoded using UTF (StatusResponse uses - # `write_utf`), so `write_utf` writes the length of the string as a varint before writing the string itself. - expected_buffer.read_varint() - expected_bytes = expected_buffer.flush() - - buffer = StatusResponse(data=data).serialize() - buffer.read_varint() # Ditto - out = buffer.flush() - - assert json.loads(out) == json.loads(expected_bytes) - - -@pytest.mark.parametrize( - ("read_bytes", "expected_data"), - [ - ( bytes.fromhex( - "797b226465736372697074696f6e223a7b2274657874223a2241204d696e6" - "5637261667420536572766572227d2c22706c6179657273223a7b226d6178" - "223a32302c226f6e6c696e65223a20307d2c2276657273696f6e223a7b226" - "e616d65223a22312e31382e31222c2270726f746f636f6c223a3735377d7d" + "84017b226465736372697074696f6e223a207b2274657874223a202241204" + "d696e65637261667420536572766572227d2c2022706c6179657273223a20" + "7b226d6178223a2032302c20226f6e6c696e65223a20307d2c20227665727" + "3696f6e223a207b226e616d65223a2022312e31382e31222c202270726f74" + "6f636f6c223a203735377d7d" + # Contains spaces that are not present in the expected bytes. + # "5637261667420536572766572227d2c22706c6179657273223a7b226d6178" + # "223a32302c226f6e6c696e65223a20307d2c2276657273696f6e223a7b226" + # "e616d65223a22312e31382e31222c2270726f746f636f6c223a3735377d7d" ), - { - "description": {"text": "A Minecraft Server"}, - "players": {"max": 20, "online": 0}, - "version": {"name": "1.18.1", "protocol": 757}, - }, ), + # Unserializable data for JSON + (({"data": object()},), ValueError), ], ) -def test_deserialize(read_bytes: list[int], expected_data: dict[str, Any]): - """Test deserialization of StatusResponse packet.""" - status = StatusResponse.deserialize(Buffer(read_bytes)) - - assert expected_data == status.data diff --git a/tests/mcproto/types/test_chat.py b/tests/mcproto/types/test_chat.py index 8730151c..3c8f05a6 100644 --- a/tests/mcproto/types/test_chat.py +++ b/tests/mcproto/types/test_chat.py @@ -2,54 +2,8 @@ import pytest -from mcproto.buffer import Buffer from mcproto.types.chat import ChatMessage, RawChatMessage, RawChatMessageDict - - -@pytest.mark.parametrize( - ("data", "expected_bytes"), - [ - ( - "A Minecraft Server", - bytearray.fromhex("142241204d696e6563726166742053657276657222"), - ), - ( - {"text": "abc"}, - bytearray.fromhex("0f7b2274657874223a2022616263227d"), - ), - ( - [{"text": "abc"}, {"text": "def"}], - bytearray.fromhex("225b7b2274657874223a2022616263227d2c207b2274657874223a2022646566227d5d"), - ), - ], -) -def test_serialize(data: RawChatMessage, expected_bytes: list[int]): - """Test serialization of ChatMessage results in expected bytes.""" - output_bytes = ChatMessage(data).serialize() - assert output_bytes == expected_bytes - - -@pytest.mark.parametrize( - ("input_bytes", "data"), - [ - ( - bytearray.fromhex("142241204d696e6563726166742053657276657222"), - "A Minecraft Server", - ), - ( - bytearray.fromhex("0f7b2274657874223a2022616263227d"), - {"text": "abc"}, - ), - ( - bytearray.fromhex("225b7b2274657874223a2022616263227d2c207b2274657874223a2022646566227d5d"), - [{"text": "abc"}, {"text": "def"}], - ), - ], -) -def test_deserialize(input_bytes: list[int], data: RawChatMessage): - """Test deserialization of ChatMessage with expected bytes produces expected ChatMessage.""" - chat = ChatMessage.deserialize(Buffer(input_bytes)) - assert chat.raw == data +from tests.helpers import gen_serializable_test @pytest.mark.parametrize( @@ -93,3 +47,30 @@ def test_as_dict(raw: RawChatMessage, expected_dict: RawChatMessageDict): def test_equality(raw1: RawChatMessage, raw2: RawChatMessage, expected_result: bool): """Test comparing ChatMessage instances produces expected equality result.""" assert (ChatMessage(raw1) == ChatMessage(raw2)) is expected_result + + +gen_serializable_test( + context=globals(), + cls=ChatMessage, + fields=[("raw", RawChatMessage)], + test_data=[ + ( + ("A Minecraft Server",), + bytes.fromhex("142241204d696e6563726166742053657276657222"), + ), + ( + ({"text": "abc"},), + bytes.fromhex("0f7b2274657874223a2022616263227d"), + ), + ( + ([{"text": "abc"}, {"text": "def"}],), + bytes.fromhex("225b7b2274657874223a2022616263227d2c207b2274657874223a2022646566227d5d"), + ), + # Wrong type for raw + ((b"invalid",), TypeError), + (({"no_extra_or_text": "invalid"},), AttributeError), + (([{"no_text": "invalid"}, {"text": "Hello"}, {"extra": "World"}],), AttributeError), + # Expects a list of dicts if raw is a list + (([[]],), TypeError), + ], +) diff --git a/tests/mcproto/types/test_nbt.py b/tests/mcproto/types/test_nbt.py index f249410b..671f1707 100644 --- a/tests/mcproto/types/test_nbt.py +++ b/tests/mcproto/types/test_nbt.py @@ -1,7 +1,7 @@ from __future__ import annotations import struct -from typing import Any, cast +from typing import Any, Dict, List, cast import pytest @@ -19,978 +19,380 @@ LongArrayNBT, LongNBT, NBTag, - NBTagType, - PayloadType, ShortNBT, StringNBT, ) +from tests.helpers import gen_serializable_test # region EndNBT - -def test_serialize_deserialize_end(): - """Test serialization/deserialization of NBT END tag.""" - output_bytes = EndNBT().serialize() - assert output_bytes == bytearray.fromhex("00") - - buffer = Buffer() - EndNBT().write_to(buffer) - assert buffer == bytearray.fromhex("00") - - buffer.clear() - EndNBT().write_to(buffer, with_name=False) - assert buffer == bytearray.fromhex("00") - - buffer = Buffer(bytearray.fromhex("00")) - assert EndNBT.deserialize(buffer) == EndNBT() +gen_serializable_test( + context=globals(), + cls=EndNBT, + fields=[], + test_data=[ + ((), b"\x00"), + (IOError, b"\x01"), + ], +) # endregion # region Numerical NBT tests - -@pytest.mark.parametrize( - ("nbt_class", "value", "expected_bytes"), - [ - (ByteNBT, 0, bytearray.fromhex("01 00")), - (ByteNBT, 1, bytearray.fromhex("01 01")), - (ByteNBT, 127, bytearray.fromhex("01 7F")), - (ByteNBT, -128, bytearray.fromhex("01 80")), - (ByteNBT, -1, bytearray.fromhex("01 FF")), - (ByteNBT, 12, bytearray.fromhex("01 0C")), - (ShortNBT, 0, bytearray.fromhex("02 00 00")), - (ShortNBT, 1, bytearray.fromhex("02 00 01")), - (ShortNBT, 32767, bytearray.fromhex("02 7F FF")), - (ShortNBT, -32768, bytearray.fromhex("02 80 00")), - (ShortNBT, -1, bytearray.fromhex("02 FF FF")), - (ShortNBT, 12, bytearray.fromhex("02 00 0C")), - (IntNBT, 0, bytearray.fromhex("03 00 00 00 00")), - (IntNBT, 1, bytearray.fromhex("03 00 00 00 01")), - (IntNBT, 2147483647, bytearray.fromhex("03 7F FF FF FF")), - (IntNBT, -2147483648, bytearray.fromhex("03 80 00 00 00")), - (IntNBT, -1, bytearray.fromhex("03 FF FF FF FF")), - (IntNBT, 12, bytearray.fromhex("03 00 00 00 0C")), - (LongNBT, 0, bytearray.fromhex("04 00 00 00 00 00 00 00 00")), - (LongNBT, 1, bytearray.fromhex("04 00 00 00 00 00 00 00 01")), - (LongNBT, (1 << 63) - 1, bytearray.fromhex("04 7F FF FF FF FF FF FF FF")), - (LongNBT, -(1 << 63), bytearray.fromhex("04 80 00 00 00 00 00 00 00")), - (LongNBT, -1, bytearray.fromhex("04 FF FF FF FF FF FF FF FF")), - (LongNBT, 12, bytearray.fromhex("04 00 00 00 00 00 00 00 0C")), - (FloatNBT, 1.0, bytearray.fromhex("05") + bytes(struct.pack(">f", 1.0))), - (FloatNBT, 0.25, bytearray.fromhex("05") + bytes(struct.pack(">f", 0.25))), - (FloatNBT, -1.0, bytearray.fromhex("05") + bytes(struct.pack(">f", -1.0))), - (FloatNBT, 12.0, bytearray.fromhex("05") + bytes(struct.pack(">f", 12.0))), - (DoubleNBT, 1.0, bytearray.fromhex("06") + bytes(struct.pack(">d", 1.0))), - (DoubleNBT, 0.25, bytearray.fromhex("06") + bytes(struct.pack(">d", 0.25))), - (DoubleNBT, -1.0, bytearray.fromhex("06") + bytes(struct.pack(">d", -1.0))), - (DoubleNBT, 12.0, bytearray.fromhex("06") + bytes(struct.pack(">d", 12.0))), - (ByteArrayNBT, b"", bytearray.fromhex("07 00 00 00 00")), - (ByteArrayNBT, b"\x00", bytearray.fromhex("07 00 00 00 01") + b"\x00"), - (ByteArrayNBT, b"\x00\x01", bytearray.fromhex("07 00 00 00 02") + b"\x00\x01"), - ( - ByteArrayNBT, - b"\x00\x01\x02", - bytearray.fromhex("07 00 00 00 03") + b"\x00\x01\x02", - ), - ( - ByteArrayNBT, - b"\x00\x01\x02\x03", - bytearray.fromhex("07 00 00 00 04") + b"\x00\x01\x02\x03", - ), - ( - ByteArrayNBT, - b"\xff" * 1024, - bytearray.fromhex("07 00 00 04 00") + b"\xff" * 1024, - ), - ( - ByteArrayNBT, - bytes((n - 1) * n * 2 % 256 for n in range(256)), - bytearray.fromhex("07 00 00 01 00") + bytes((n - 1) * n * 2 % 256 for n in range(256)), - ), - (StringNBT, "", bytearray.fromhex("08 00 00")), - (StringNBT, "test", bytearray.fromhex("08 00 04") + b"test"), - (StringNBT, "a" * 100, bytearray.fromhex("08 00 64") + b"a" * (100)), - (StringNBT, "&à@é", bytearray.fromhex("08 00 06") + bytes("&à@é", "utf-8")), - (ListNBT, [], bytearray.fromhex("09 00 00 00 00 00")), - (ListNBT, [ByteNBT(0)], bytearray.fromhex("09 01 00 00 00 01 00")), - ( - ListNBT, - [ShortNBT(127), ShortNBT(256)], - bytearray.fromhex("09 02 00 00 00 02 00 7F 01 00"), - ), - ( - ListNBT, - [ListNBT([ByteNBT(0)]), ListNBT([IntNBT(256)])], - bytearray.fromhex("09 09 00 00 00 02 01 00 00 00 01 00 03 00 00 00 01 00 00 01 00"), - ), - (CompoundNBT, [], bytearray.fromhex("0A 00")), - ( - CompoundNBT, - [ByteNBT(0, name="test")], - bytearray.fromhex("0A") + ByteNBT(0, name="test").serialize() + b"\x00", - ), - ( - CompoundNBT, - [ShortNBT(128, "Short"), ByteNBT(-1, "Byte")], - bytearray.fromhex("0A") + ShortNBT(128, "Short").serialize() + ByteNBT(-1, "Byte").serialize() + b"\x00", - ), - ( - CompoundNBT, - [CompoundNBT([ByteNBT(0, name="Byte")], name="test")], - bytearray.fromhex("0A") + CompoundNBT([ByteNBT(0, name="Byte")], name="test").serialize() + b"\x00", - ), - ( - CompoundNBT, - [ - CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test"), - IntNBT(-1, "Int 2"), - ], - bytearray.fromhex("0A") - + CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test").serialize() - + IntNBT(-1, "Int 2").serialize() - + b"\x00", - ), - (IntArrayNBT, [], bytearray.fromhex("0B 00 00 00 00")), - (IntArrayNBT, [0], bytearray.fromhex("0B 00 00 00 01 00 00 00 00")), - ( - IntArrayNBT, - [0, 1], - bytearray.fromhex("0B 00 00 00 02 00 00 00 00 00 00 00 01"), - ), - ( - IntArrayNBT, - [1, 2, 3], - bytearray.fromhex("0B 00 00 00 03 00 00 00 01 00 00 00 02 00 00 00 03"), - ), - (IntArrayNBT, [(1 << 31) - 1], bytearray.fromhex("0B 00 00 00 01 7F FF FF FF")), - ( - IntArrayNBT, - [(1 << 31) - 1, (1 << 31) - 2], - bytearray.fromhex("0B 00 00 00 02 7F FF FF FF 7F FF FF FE"), - ), - ( - IntArrayNBT, - [-1, -2, -3], - bytearray.fromhex("0B 00 00 00 03 FF FF FF FF FF FF FF FE FF FF FF FD"), - ), - ( - IntArrayNBT, - [12] * 1024, - bytearray.fromhex("0B 00 00 04 00") + b"\x00\x00\x00\x0c" * 1024, - ), - (LongArrayNBT, [], bytearray.fromhex("0C 00 00 00 00")), - ( - LongArrayNBT, - [0], - bytearray.fromhex("0C 00 00 00 01 00 00 00 00 00 00 00 00"), - ), - ( - LongArrayNBT, - [0, 1], - bytearray.fromhex("0C 00 00 00 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 01"), - ), - ( - LongArrayNBT, - [1, 2, 3], - bytearray.fromhex( - "0C 00 00 00 03 00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 03" - ), - ), - ( - LongArrayNBT, - [(1 << 63) - 1], - bytearray.fromhex("0C 00 00 00 01 7F FF FF FF FF FF FF FF"), - ), - ( - LongArrayNBT, - [(1 << 63) - 1, (1 << 63) - 2], - bytearray.fromhex("0C 00 00 00 02 7F FF FF FF FF FF FF FF 7F FF FF FF FF FF FF FE"), - ), - ( - LongArrayNBT, - [-1, -2, -3], - bytearray.fromhex( - "0C 00 00 00 03 FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF FE FF FF FF FF FF FF FF FD" - ), - ), - ( - LongArrayNBT, - [12] * 1024, - bytearray.fromhex("0C 00 00 04 00") + b"\x00\x00\x00\x00\x00\x00\x00\x0c" * 1024, - ), +gen_serializable_test( + context=globals(), + cls=ByteNBT, + fields=[("payload", int), ("name", str)], + test_data=[ + ((0, "a"), b"\x01\x00\x01a\x00"), + ((1, "test"), b"\x01\x00\x04test\x01"), + ((127, "&à@é"), b"\x01\x00\x06" + bytes("&à@é", "utf-8") + b"\x7f"), + ((-128, "test"), b"\x01\x00\x04test\x80"), + ((-1, "a" * 100), b"\x01\x00\x64" + b"a" * 100 + b"\xff"), + # Errors + (IOError, b"\x01\x00\x04test"), + (IOError, b"\x01\x00\x04tes"), + (IOError, b"\x01\x00"), + (IOError, b"\x01"), + # Wrong type + (TypeError, b"\x02\x00\x01a\x00"), + (TypeError, b"\xff\x00\x01a\x00"), + # Out of bounds + ((1 << 7, "a"), OverflowError), + ((-(1 << 7) - 1, "a"), OverflowError), + ((1 << 8, "a"), OverflowError), + ((-(1 << 8) - 1, "a"), OverflowError), + ((1000, "a"), OverflowError), + ((1.5, "a"), TypeError), ], ) -def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType, expected_bytes: bytes): - """Test serialization/deserialization of NBT tag without name.""" - # Test serialization - output_bytes = nbt_class(value).serialize(with_name=False) - output_bytes_no_type = nbt_class(value).serialize(with_type=False, with_name=False) - assert output_bytes == expected_bytes - assert output_bytes_no_type == expected_bytes[1:] - - buffer = Buffer() - nbt_class(value).write_to(buffer, with_name=False) - assert buffer == expected_bytes - - # Test deserialization - buffer = Buffer(expected_bytes) - assert NBTag.deserialize(buffer, with_name=False) == nbt_class(value) - - buffer = Buffer(expected_bytes[1:]) - assert nbt_class.deserialize(buffer, with_type=False, with_name=False) == nbt_class(value) - - buffer = Buffer(expected_bytes) - assert nbt_class.read_from(buffer, with_name=False) == nbt_class(value) - - buffer = Buffer(expected_bytes[1:]) - assert nbt_class.read_from(buffer, with_type=False, with_name=False) == nbt_class(value) - -@pytest.mark.parametrize( - ("nbt_class", "value", "name", "expected_bytes"), - [ - ( - ByteNBT, - 0, - "test", - bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("00"), - ), - ( - ByteNBT, - 1, - "a", - bytearray.fromhex("01") + b"\x00\x01a" + bytearray.fromhex("01"), - ), - ( - ByteNBT, - 127, - "&à@é", - bytearray.fromhex("01 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F"), - ), - ( - ByteNBT, - -128, - "test", - bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("80"), - ), - ( - ByteNBT, - 12, - "a" * 100, - bytearray.fromhex("01") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("0C"), - ), - ( - ShortNBT, - 0, - "test", - bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("00 00"), - ), - ( - ShortNBT, - 1, - "a", - bytearray.fromhex("02") + b"\x00\x01a" + bytearray.fromhex("00 01"), - ), - ( - ShortNBT, - 32767, - "&à@é", - bytearray.fromhex("02 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF"), - ), - ( - ShortNBT, - -32768, - "test", - bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("80 00"), - ), - ( - ShortNBT, - 12, - "a" * 100, - bytearray.fromhex("02") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 0C"), - ), - ( - IntNBT, - 0, - "test", - bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), - ), - ( - IntNBT, - 1, - "a", - bytearray.fromhex("03") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01"), - ), - ( - IntNBT, - 2147483647, - "&à@é", - bytearray.fromhex("03 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF FF FF"), - ), - ( - IntNBT, - -2147483648, - "test", - bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00"), - ), - ( - IntNBT, - 12, - "a" * 100, - bytearray.fromhex("03") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 0C"), - ), - ( - LongNBT, - 0, - "test", - bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00 00 00 00"), - ), - ( - LongNBT, - 1, - "a", - bytearray.fromhex("04") + b"\x00\x01a" + bytearray.fromhex("00 00 00 00 00 00 00 01"), - ), - ( - LongNBT, - (1 << 63) - 1, - "&à@é", - bytearray.fromhex("04 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF FF FF FF FF FF FF"), - ), - ( - LongNBT, - -1 << 63, - "test", - bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00 00 00 00 00"), - ), - ( - LongNBT, - 12, - "a" * 100, - bytearray.fromhex("04") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 00 00 00 00 0C"), - ), - ( - FloatNBT, - 1.0, - "test", - bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 1.0)), - ), - ( - FloatNBT, - 0.25, - "a", - bytearray.fromhex("05") + b"\x00\x01a" + bytes(struct.pack(">f", 0.25)), - ), - ( - FloatNBT, - -1.0, - "&à@é", - bytearray.fromhex("05 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">f", -1.0)), - ), - ( - FloatNBT, - 12.0, - "test", - bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 12.0)), - ), - ( - DoubleNBT, - 1.0, - "test", - bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 1.0)), - ), - ( - DoubleNBT, - 0.25, - "a", - bytearray.fromhex("06") + b"\x00\x01a" + bytes(struct.pack(">d", 0.25)), - ), - ( - DoubleNBT, - -1.0, - "&à@é", - bytearray.fromhex("06 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">d", -1.0)), - ), - ( - DoubleNBT, - 12.0, - "test", - bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 12.0)), - ), - ( - ByteArrayNBT, - b"", - "test", - bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), - ), - ( - ByteArrayNBT, - b"\x00", - "a", - bytearray.fromhex("07") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01") + b"\x00", - ), - ( - ByteArrayNBT, - b"\x00\x01", - "&à@é", - bytearray.fromhex("07 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("00 00 00 02") + b"\x00\x01", - ), - ( - ByteArrayNBT, - b"\x00\x01\x02", - "test", - bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 03") + b"\x00\x01\x02", - ), - ( - ByteArrayNBT, - b"\xff" * 1024, - "a" * 100, - bytearray.fromhex("07") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 04 00") + b"\xff" * 1024, - ), - ( - StringNBT, - "", - "test", - bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 00"), - ), - ( - StringNBT, - "test", - "a", - bytearray.fromhex("08") + b"\x00\x01a" + bytearray.fromhex("00 04") + b"test", - ), - ( - StringNBT, - "a" * 100, - "&à@é", - bytearray.fromhex("08 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("00 64") + b"a" * 100, - ), - ( - StringNBT, - "&à@é", - "test", - bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 06") + bytes("&à@é", "utf-8"), - ), - ( - ListNBT, - [], - "test", - bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00"), - ), - ( - ListNBT, - [ByteNBT(-1)], - "a", - bytearray.fromhex("09") + b"\x00\x01a" + bytearray.fromhex("01 00 00 00 01 FF"), - ), - ( - ListNBT, - [ShortNBT(127), ShortNBT(256)], - "test", - bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("02 00 00 00 02 00 7F 01 00"), - ), - ( - ListNBT, - [ListNBT([ByteNBT(-1)]), ListNBT([IntNBT(256)])], - "a", - bytearray.fromhex("09") - + b"\x00\x01a" - + bytearray.fromhex("09 00 00 00 02 01 00 00 00 01 FF 03 00 00 00 01 00 00 01 00"), - ), - ( - CompoundNBT, - [], - "test", - bytearray.fromhex("0A") + b"\x00\x04test" + bytearray.fromhex("00"), - ), - ( - CompoundNBT, - [ByteNBT(0, name="Byte")], - "test", - bytearray.fromhex("0A") + b"\x00\x04test" + ByteNBT(0, name="Byte").serialize() + b"\x00", - ), - ( - CompoundNBT, - [ShortNBT(128, "Short"), ByteNBT(-1, "Byte")], - "test", - bytearray.fromhex("0A") - + b"\x00\x04test" - + ShortNBT(128, "Short").serialize() - + ByteNBT(-1, "Byte").serialize() - + b"\x00", - ), - ( - CompoundNBT, - [CompoundNBT([ByteNBT(0, name="Byte")], name="test")], - "test", - bytearray.fromhex("0A") - + b"\x00\x04test" - + CompoundNBT([ByteNBT(0, name="Byte")], "test").serialize() - + b"\x00", - ), - ( - CompoundNBT, - [ListNBT([ByteNBT(0)], name="List")], - "test", - bytearray.fromhex("0A") + b"\x00\x04test" + ListNBT([ByteNBT(0)], name="List").serialize() + b"\x00", - ), - ( - IntArrayNBT, - [], - "test", - bytearray.fromhex("0B") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), - ), - ( - IntArrayNBT, - [0], - "a", - bytearray.fromhex("0B") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01") + b"\x00\x00\x00\x00", - ), - ( - IntArrayNBT, - [0, 1], - "&à@é", - bytearray.fromhex("0B 00 06") - + bytes("&à@é", "utf-8") - + bytearray.fromhex("00 00 00 02") - + b"\x00\x00\x00\x00\x00\x00\x00\x01", - ), - ( - IntArrayNBT, - [1, 2, 3], - "test", - bytearray.fromhex("0B") - + b"\x00\x04test" - + bytearray.fromhex("00 00 00 03") - + b"\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03", - ), - ( - IntArrayNBT, - [(1 << 31) - 1], - "a" * 100, - bytearray.fromhex("0B") - + b"\x00\x64" - + b"a" * 100 - + bytearray.fromhex("00 00 00 01") - + b"\x7f\xff\xff\xff", - ), - ( - LongArrayNBT, - [], - "test", - bytearray.fromhex("0C") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), - ), - ( - LongArrayNBT, - [0], - "a", - bytearray.fromhex("0C") - + b"\x00\x01a" - + bytearray.fromhex("00 00 00 01") - + b"\x00\x00\x00\x00\x00\x00\x00\x00", - ), - ( - LongArrayNBT, - [0, 1], - "&à@é", - bytearray.fromhex("0C 00 06") - + bytes("&à@é", "utf-8") - + bytearray.fromhex("00 00 00 02") - + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - ), - ( - LongArrayNBT, - [1, 2, 3], - "test", - bytearray.fromhex("0C") - + b"\x00\x04test" - + bytearray.fromhex("00 00 00 03") - + bytearray.fromhex("00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 03"), - ), - ( - LongArrayNBT, - [(1 << 63) - 1] * 100, - "a" * 100, - bytearray.fromhex("0C") - + b"\x00\x64" - + b"a" * 100 - + bytearray.fromhex("00 00 00 64") - + b"\x7f\xff\xff\xff\xff\xff\xff\xff" * 100, - ), +gen_serializable_test( + context=globals(), + cls=ShortNBT, + fields=[("payload", int), ("name", str)], + test_data=[ + ((0, "a"), b"\x02\x00\x01a\x00\x00"), + ((1, "test"), b"\x02\x00\x04test\x00\x01"), + ((32767, "&à@é"), b"\x02\x00\x06" + bytes("&à@é", "utf-8") + b"\x7f\xff"), + ((-32768, "test"), b"\x02\x00\x04test\x80\x00"), + ((-1, "a" * 100), b"\x02\x00\x64" + b"a" * 100 + b"\xff\xff"), + # Errors + (IOError, b"\x02\x00\x04test"), + (IOError, b"\x02\x00\x04tes"), + (IOError, b"\x02\x00"), + (IOError, b"\x02"), + # Out of bounds + ((1 << 15, "a"), OverflowError), + ((-(1 << 15) - 1, "a"), OverflowError), + ((1 << 16, "a"), OverflowError), + ((-(1 << 16) - 1, "a"), OverflowError), + ((int(1e10), "a"), OverflowError), ], ) -def test_serialize_deserialize(nbt_class: type[NBTag], value: PayloadType, name: str, expected_bytes: bytes): - """Test serialization/deserialization of NBT tag with name.""" - # Test serialization - output_bytes = nbt_class(value, name).serialize() - output_bytes_no_type = nbt_class(value, name).serialize(with_type=False) - assert output_bytes == expected_bytes - assert output_bytes_no_type == expected_bytes[1:] - - buffer = Buffer() - nbt_class(value, name).write_to(buffer) - assert buffer == expected_bytes - - # Test deserialization - buffer = Buffer(expected_bytes * 2) - assert buffer.remaining == len(expected_bytes) * 2 - assert NBTag.deserialize(buffer) == nbt_class(value, name=name) - assert buffer.remaining == len(expected_bytes) - assert NBTag.deserialize(buffer) == nbt_class(value, name=name) - assert buffer.remaining == 0 - - buffer = Buffer(expected_bytes[1:]) - assert nbt_class.deserialize(buffer, with_type=False) == nbt_class(value, name=name) - buffer = Buffer(expected_bytes) - assert nbt_class.read_from(buffer) == nbt_class(value, name=name) - - buffer = Buffer(expected_bytes[1:]) - assert nbt_class.read_from(buffer, with_type=False) == nbt_class(value, name=name) - - -@pytest.mark.parametrize( - ("nbt_class", "size", "tag"), - [ - (ByteNBT, 8, NBTagType.BYTE), - (ShortNBT, 16, NBTagType.SHORT), - (IntNBT, 32, NBTagType.INT), - (LongNBT, 64, NBTagType.LONG), +gen_serializable_test( + context=globals(), + cls=IntNBT, + fields=[("payload", int), ("name", str)], + test_data=[ + ((0, "a"), b"\x03\x00\x01a\x00\x00\x00\x00"), + ((1, "test"), b"\x03\x00\x04test\x00\x00\x00\x01"), + ((2147483647, "&à@é"), b"\x03\x00\x06" + bytes("&à@é", "utf-8") + b"\x7f\xff\xff\xff"), + ((-2147483648, "test"), b"\x03\x00\x04test\x80\x00\x00\x00"), + ((-1, "a" * 100), b"\x03\x00\x64" + b"a" * 100 + b"\xff\xff\xff\xff"), + # Errors + (IOError, b"\x03\x00\x04test"), + (IOError, b"\x03\x00\x04tes"), + (IOError, b"\x03\x00"), + (IOError, b"\x03"), + # Out of bounds + ((1 << 31, "a"), OverflowError), + ((-(1 << 31) - 1, "a"), OverflowError), + ((1 << 32, "a"), OverflowError), + ((-(1 << 32) - 1, "a"), OverflowError), + ((int(1e30), "a"), OverflowError), ], ) -def test_serialize_deserialize_numerical_fail(nbt_class: type[NBTag], size: int, tag: NBTagType): - """Test serialization/deserialization of NBT NUM tag with invalid value.""" - # Out of bounds - with pytest.raises(OverflowError): - nbt_class(1 << (size - 1)).serialize(with_name=False) - - with pytest.raises(OverflowError): - nbt_class(-(1 << (size - 1)) - 1).serialize(with_name=False) - - # Deserialization - buffer = Buffer(bytearray([tag.value + 1] + [0] * (size // 8))) - with pytest.raises(TypeError): # Tries to read a nbt_class, but it's one higher - nbt_class.deserialize(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([tag.value] + [0] * ((size // 8) - 1))) - with pytest.raises(IOError): - nbt_class.read_from(buffer, with_name=False) - - buffer = Buffer(bytearray([tag.value, 0, 0] + [0] * (size // 8))) - assert nbt_class.read_from(buffer, with_name=True) == nbt_class(0) - - -# endregion - -# region FloatNBT - - -def test_serialize_deserialize_float_fail(): - """Test serialization/deserialization of NBT FLOAT tag with invalid value.""" - with pytest.raises(struct.error): - FloatNBT("test").serialize(with_name=False) - - with pytest.raises(OverflowError): - FloatNBT(1e39, "test").serialize() - - with pytest.raises(OverflowError): - FloatNBT(-1e39, "test").serialize() - - # Deserialization - buffer = Buffer(bytearray([NBTagType.BYTE] + [0] * 4)) - with pytest.raises(TypeError): # Tries to read a FloatNBT, but it's a ByteNBT - FloatNBT.deserialize(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.FLOAT, 0, 0, 0])) - with pytest.raises(IOError): - FloatNBT.read_from(buffer, with_name=False) - - -# endregion -# region DoubleNBT - - -def test_serialize_deserialize_double_fail(): - """Test serialization/deserialization of NBT DOUBLE tag with invalid value.""" - with pytest.raises(struct.error): - DoubleNBT("test").serialize(with_name=False) - - # Deserialization - buffer = Buffer(bytearray([0x01] + [0] * 8)) - with pytest.raises(TypeError): # Tries to read a DoubleNBT, but it's a ByteNBT - DoubleNBT.deserialize(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.DOUBLE, 0, 0, 0, 0, 0, 0, 0])) - with pytest.raises(IOError): - DoubleNBT.read_from(buffer, with_name=False) - - -# endregion -# region ByteArrayNBT - - -def test_serialize_deserialize_bytearray_fail(): - """Test serialization/deserialization of NBT BYTEARRAY tag with invalid value.""" - # Deserialization - buffer = Buffer(bytearray([0x01] + [0] * 4)) - with pytest.raises(TypeError): # Tries to read a ByteArrayNBT, but it's a ByteNBT - ByteArrayNBT.deserialize(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0])) # Missing length bytes - with pytest.raises(IOError): - ByteArrayNBT.read_from(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0, 1])) # Missing data bytes - with pytest.raises(IOError): - ByteArrayNBT.read_from(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0, 2, 0])) # Missing data bytes - with pytest.raises(IOError): - ByteArrayNBT.read_from(buffer, with_name=False) - - # Negative length - buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0xFF, 0xFF, 0xFF, 0xFF])) # length = -1 - with pytest.raises(ValueError): - ByteArrayNBT.deserialize(buffer, with_name=False) - - -# endregion -# region StringNBT - - -def test_serialize_deserialize_string_fail(): - """Test serialization/deserialization of NBT STRING tag with invalid value.""" - # Deserialization - buffer = Buffer(bytearray([0x01, 0, 0])) - with pytest.raises(TypeError): # Tries to read a StringNBT, but it's a ByteNBT - StringNBT.deserialize(buffer, with_name=False) - - # Not enough data for the length - buffer = Buffer(bytearray([NBTagType.STRING, 0])) - with pytest.raises(IOError): - StringNBT.read_from(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.STRING, 0, 1])) - with pytest.raises(IOError): - StringNBT.read_from(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.STRING, 0, 2, 0])) - with pytest.raises(IOError): - StringNBT.read_from(buffer, with_name=False) - - # Negative length - buffer = Buffer(bytearray([NBTagType.STRING, 0xFF, 0xFF])) # length = -1 - with pytest.raises(ValueError): - StringNBT.deserialize(buffer, with_name=False) - - # Invalid UTF-8 - buffer = Buffer(bytearray([NBTagType.STRING, 0, 1, 0xC0, 0x80])) - with pytest.raises(UnicodeDecodeError): - StringNBT.read_from(buffer, with_name=False) +gen_serializable_test( + context=globals(), + cls=LongNBT, + fields=[("payload", int), ("name", str)], + test_data=[ + ((0, "a"), b"\x04\x00\x01a\x00\x00\x00\x00\x00\x00\x00\x00"), + ((1, "test"), b"\x04\x00\x04test\x00\x00\x00\x00\x00\x00\x00\x01"), + (((1 << 63) - 1, "&à@é"), b"\x04\x00\x06" + bytes("&à@é", "utf-8") + b"\x7f\xff\xff\xff\xff\xff\xff\xff"), + ((-1 << 63, "test"), b"\x04\x00\x04test\x80\x00\x00\x00\x00\x00\x00\x00"), + ((-1, "a" * 100), b"\x04\x00\x64" + b"a" * 100 + b"\xff\xff\xff\xff\xff\xff\xff\xff"), + # Errors + (IOError, b"\x04\x00\x04test"), + (IOError, b"\x04\x00\x04tes"), + (IOError, b"\x04\x00"), + (IOError, b"\x04"), + # Out of bounds + ((1 << 63, "a"), OverflowError), + ((-(1 << 63) - 1, "a"), OverflowError), + ((1 << 64, "a"), OverflowError), + ((-(1 << 64) - 1, "a"), OverflowError), + ], +) # endregion -# region ListNBT - - -@pytest.mark.parametrize( - ("payload", "error"), - [ - ([ByteNBT(0), IntNBT(0)], ValueError), - ([ByteNBT(0), "test"], ValueError), - ([ByteNBT(0), None], ValueError), - ([ByteNBT(0), ByteNBT(-1, "Hello World")], ValueError), # All unnamed tags - ([ByteNBT(128), ByteNBT(-1)], OverflowError), # Check for error propagation +# region Floating point NBT tests +gen_serializable_test( + context=globals(), + cls=FloatNBT, + fields=[("payload", float), ("name", str)], + test_data=[ + ((1.0, "a"), b"\x05\x00\x01a" + bytes(struct.pack(">f", 1.0))), + ((0.5, "test"), b"\x05\x00\x04test" + bytes(struct.pack(">f", 0.5))), # has to be convertible to float exactly + ((-1.0, "&à@é"), b"\x05\x00\x06" + bytes("&à@é", "utf-8") + bytes(struct.pack(">f", -1.0))), + ((12.0, "a" * 100), b"\x05\x00\x64" + b"a" * 100 + bytes(struct.pack(">f", 12.0))), + ((1, "a"), b"\x05\x00\x01a" + bytes(struct.pack(">f", 1.0))), + # Errors + (IOError, b"\x05\x00\x04test"), + (IOError, b"\x05\x00\x04tes"), + (IOError, b"\x05\x00"), + (IOError, b"\x05"), + # Wrong type + (("1.5", "a"), TypeError), ], ) -def test_serialize_list_fail(payload: PayloadType, error: type[Exception]): - """Test serialization of NBT LIST tag with invalid value.""" - with pytest.raises(error): - ListNBT(payload, "test").serialize() - - -def test_deserialize_list_fail(): - """Test deserialization of NBT LIST tag with invalid value.""" - # Wrong tag type - buffer = Buffer(bytearray([0x09, 255, 0, 0, 0, 1, 0])) - with pytest.raises(TypeError): - ListNBT.deserialize(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([0x09, 1, 0, 0, 0, 1])) - with pytest.raises(IOError): - ListNBT.read_from(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([0x09, 1, 0, 0, 0])) - with pytest.raises(IOError): - ListNBT.read_from(buffer, with_name=False) - +gen_serializable_test( + context=globals(), + cls=DoubleNBT, + fields=[("payload", float), ("name", str)], + test_data=[ + ((1.0, "a"), b"\x06\x00\x01a" + bytes(struct.pack(">d", 1.0))), + ((3.14, "test"), b"\x06\x00\x04test" + bytes(struct.pack(">d", 3.14))), + ((-1.0, "&à@é"), b"\x06\x00\x06" + bytes("&à@é", "utf-8") + bytes(struct.pack(">d", -1.0))), + ((12.0, "a" * 100), b"\x06\x00\x64" + b"a" * 100 + bytes(struct.pack(">d", 12.0))), + # Errors + (IOError, b"\x06\x00\x04test\x01"), + (IOError, b"\x06\x00\x04test"), + (IOError, b"\x06\x00\x04tes"), + (IOError, b"\x06\x00"), + (IOError, b"\x06"), + ], +) # endregion -# region CompoundNBT - - -@pytest.mark.parametrize( - ("payload", "error"), - [ - ([ByteNBT(0, name="Hello"), IntNBT(0)], ValueError), - ([ByteNBT(0, name="hi"), "test"], ValueError), - ([ByteNBT(0, name="hi"), None], ValueError), - ([ByteNBT(0), ByteNBT(-1, "Hello World")], ValueError), # All unnamed tags - ( - [ByteNBT(128, name="Jello"), ByteNBT(-1, name="Bonjour")], - OverflowError, - ), # Check for error propagation +# region Variable Length NBT tests +gen_serializable_test( + context=globals(), + cls=ByteArrayNBT, + fields=[("payload", bytes), ("name", str)], + test_data=[ + ((b"", "a"), b"\x07\x00\x01a\x00\x00\x00\x00"), + ((b"\x00", "test"), b"\x07\x00\x04test\x00\x00\x00\x01\x00"), + ((b"\x00\x01", "&à@é"), b"\x07\x00\x06" + bytes("&à@é", "utf-8") + b"\x00\x00\x00\x02\x00\x01"), + ((b"\x00\x01\x02", "test"), b"\x07\x00\x04test\x00\x00\x00\x03\x00\x01\x02"), + ((b"\xff" * 1024, "a" * 100), b"\x07\x00\x64" + b"a" * 100 + b"\x00\x00\x04\x00" + b"\xff" * 1024), + ((b"Hello World", "test"), b"\x07\x00\x04test\x00\x00\x00\x0b" + b"Hello World"), + ((bytearray(b"Hello World"), "test"), b"\x07\x00\x04test\x00\x00\x00\x0b" + b"Hello World"), + # Errors + (IOError, b"\x07\x00\x04test"), + (IOError, b"\x07\x00\x04tes"), + (IOError, b"\x07\x00"), + (IOError, b"\x07"), + (IOError, b"\x07\x00\x01a\x00\x01"), + (IOError, b"\x07\x00\x01a\x00\x00\x00\xff"), + # Negative length + (ValueError, b"\x07\x00\x01a\xff\xff\xff\xff"), + # Wrong type + ((1, "a"), TypeError), ], ) -def test_serialize_compound_fail(payload: PayloadType, error: type[Exception]): - """Test serialization of NBT COMPOUND tag with invalid value.""" - with pytest.raises(error): - CompoundNBT(payload, "test").serialize() - - # Double name - with pytest.raises(ValueError): - CompoundNBT([ByteNBT(0, name="test"), ByteNBT(0, name="test")], "comp").serialize() - -def test_deseialize_compound_fail(): - """Test deserialization of NBT COMPOUND tag with invalid value.""" - # Not enough data - buffer = Buffer(bytearray([NBTagType.COMPOUND, 0x01])) - with pytest.raises(IOError): - CompoundNBT.read_from(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.COMPOUND])) - with pytest.raises(IOError): - CompoundNBT.read_from(buffer, with_name=False) +gen_serializable_test( + context=globals(), + cls=StringNBT, + fields=[("payload", str), ("name", str)], + test_data=[ + (("", "a"), b"\x08\x00\x01a\x00\x00"), + (("test", "a"), b"\x08\x00\x01a\x00\x04" + b"test"), + (("a" * 100, "&à@é"), b"\x08\x00\x06" + bytes("&à@é", "utf-8") + b"\x00\x64" + b"a" * 100), + (("&à@é", "test"), b"\x08\x00\x04test\x00\x06" + bytes("&à@é", "utf-8")), + # Errors + (IOError, b"\x08\x00\x04test"), + (IOError, b"\x08\x00\x04tes"), + (IOError, b"\x08\x00"), + (IOError, b"\x08"), + # Negative length + (ValueError, b"\x08\xff\xff\xff\xff"), + # Unicode decode error + (UnicodeDecodeError, b"\x08\x00\x01a\x00\x01\xff"), + # String too long + (("a" * 32768, "b"), ValueError), + # Wrong type + ((1, "a"), TypeError), + ], +) - # Wrong tag type - buffer = Buffer(bytearray([15])) - with pytest.raises(TypeError): - NBTag.deserialize(buffer) +gen_serializable_test( + context=globals(), + cls=ListNBT, + fields=[("payload", list), ("name", str)], + test_data=[ + # Here we only want to test ListNBT related stuff + (([], "a"), b"\x09\x00\x01a\x00\x00\x00\x00\x00"), + (([ByteNBT(-1)], "a"), b"\x09\x00\x01a\x01\x00\x00\x00\x01\xff"), + (([ListNBT([])], "a"), b"\x09\x00\x01a\x09\x00\x00\x00\x01" + ListNBT([]).serialize()[1:]), + (([ListNBT([ByteNBT(6)])], "a"), b"\x09\x00\x01a\x09\x00\x00\x00\x01" + ListNBT([ByteNBT(6)]).serialize()[1:]), + ( + ([ListNBT([ByteNBT(-1)]), ListNBT([IntNBT(1234)])], "a"), + b"\x09\x00\x01a\x09\x00\x00\x00\x02" + + ListNBT([ByteNBT(-1)]).serialize()[1:] + + ListNBT([IntNBT(1234)]).serialize()[1:], + ), + ( + ([ListNBT([ByteNBT(-1)]), ListNBT([IntNBT(128), IntNBT(8)])], "a"), + b"\x09\x00\x01a\x09\x00\x00\x00\x02" + + ListNBT([ByteNBT(-1)]).serialize()[1:] + + ListNBT([IntNBT(128), IntNBT(8)]).serialize()[1:], + ), + # Errors + # Not enough data + (IOError, b"\x09\x00\x01a"), + (IOError, b"\x09\x00\x01a\x01"), + (IOError, b"\x09\x00\x01a\x01\x00"), + (IOError, b"\x09\x00\x01a\x01\x00\x00\x00\x01"), + (IOError, b"\x09\x00\x01a\x01\x00\x00\x00\x03\x01"), + # Invalid tag type + (TypeError, b"\x09\x00\x01a\xff\x00\x00\x01\x00"), + # Not NBTags + (([1, 2, 3], "a"), TypeError), + # Not the same tag type + (([ByteNBT(0), IntNBT(0)], "a"), TypeError), + # Contains named tags + (([ByteNBT(0, name="Byte")], "a"), ValueError), + # Wrong type + ((1, "a"), TypeError), + ], +) +gen_serializable_test( + context=globals(), + cls=CompoundNBT, + fields=[("payload", list), ("name", str)], + test_data=[ + (([], "a"), b"\x0a\x00\x01a\x00"), + (([ByteNBT(0, name="Byte")], "a"), b"\x0a\x00\x01a" + ByteNBT(0, name="Byte").serialize() + b"\x00"), + ( + ([ShortNBT(128, "Short"), ByteNBT(-1, "Byte")], "a"), + b"\x0a\x00\x01a" + ShortNBT(128, "Short").serialize() + ByteNBT(-1, "Byte").serialize() + b"\x00", + ), + ( + ([CompoundNBT([ByteNBT(0, name="Byte")], name="test")], "a"), + b"\x0a\x00\x01a" + CompoundNBT([ByteNBT(0, name="Byte")], "test").serialize() + b"\x00", + ), + ( + ([ListNBT([ByteNBT(0)] * 3, name="List")], "a"), + b"\x0a\x00\x01a" + ListNBT([ByteNBT(0)] * 3, name="List").serialize() + b"\x00", + ), + # Errors + # Not enough data + (IOError, b"\x0a\x00\x01a"), + (IOError, b"\x0a\x00\x01a\x01"), + # All muse be NBTags + (([0, 1, 2], "a"), TypeError), + # All with a name + (([ByteNBT(0)], "a"), ValueError), + # Must be unique + (([ByteNBT(0, name="Byte"), ByteNBT(0, name="Byte")], "a"), ValueError), + # Wrong type + ((1, "a"), TypeError), + ], +) -def test_nbtag_deserialize_compound(): - """Test deserialization of NBT COMPOUND tag from the NBTag class.""" - buf = Buffer(bytearray([0x00])) - assert NBTag.deserialize(buf, with_type=False, with_name=False) == CompoundNBT([]) +gen_serializable_test( + context=globals(), + cls=IntArrayNBT, + fields=[("payload", list), ("name", str)], + test_data=[ + (([], "a"), b"\x0b\x00\x01a\x00\x00\x00\x00"), + (([0], "a"), b"\x0b\x00\x01a\x00\x00\x00\x01\x00\x00\x00\x00"), + (([0, 1], "a"), b"\x0b\x00\x01a\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01"), + (([1, 2, 3], "a"), b"\x0b\x00\x01a\x00\x00\x00\x03\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03"), + (([(1 << 31) - 1], "a"), b"\x0b\x00\x01a\x00\x00\x00\x01\x7f\xff\xff\xff"), + (([-1, -2, -3], "a"), b"\x0b\x00\x01a\x00\x00\x00\x03\xff\xff\xff\xff\xff\xff\xff\xfe\xff\xff\xff\xfd"), + # Errors + # Not enough data + (IOError, b"\x0b\x00\x01a"), + (IOError, b"\x0b\x00\x01a\x01"), + (IOError, b"\x0b\x00\x01a\x00\x00\x00\x01"), + (IOError, b"\x0b\x00\x01a\x00\x00\x00\x03\x01"), + # Must contain ints only + ((["a"], "a"), TypeError), + (([IntNBT(0)], "a"), TypeError), + (([1 << 31], "a"), OverflowError), + (([-(1 << 31) - 1], "a"), OverflowError), + # Wrong type + ((1, "a"), TypeError), + ], +) +gen_serializable_test( + context=globals(), + cls=LongArrayNBT, + fields=[("payload", list), ("name", str)], + test_data=[ + (([], "a"), b"\x0c\x00\x01a\x00\x00\x00\x00"), + (([0], "a"), b"\x0c\x00\x01a\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00"), + ( + ([0, 1], "a"), + b"\x0c\x00\x01a\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + ), + (([(1 << 63) - 1], "a"), b"\x0c\x00\x01a\x00\x00\x00\x01\x7f\xff\xff\xff\xff\xff\xff\xff"), + ( + ([-1, -2], "a"), + b"\x0c\x00\x01a\x00\x00\x00\x02\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xfe", + ), + # Not enough data + (IOError, b"\x0c\x00\x01a"), + (IOError, b"\x0c\x00\x01a\x01"), + (IOError, b"\x0c\x00\x01a\x00\x00\x00\x01"), + (IOError, b"\x0c\x00\x01a\x00\x00\x00\x03\x01"), + # Must contain ints only + ((["a"], "a"), TypeError), + (([LongNBT(0)], "a"), TypeError), + (([1 << 63], "a"), OverflowError), + (([-(1 << 63) - 1], "a"), OverflowError), + ], +) - buf = Buffer(bytearray.fromhex("0A 00 01 61 01 00 01 62 00 00")) - assert NBTag.deserialize(buf) == CompoundNBT([ByteNBT(0, name="b")], name="a") +# endregion +# region CompoundNBT def test_equality_compound(): """Test equality of CompoundNBT.""" - comp1 = CompoundNBT( - [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], - "comp", - ) - comp2 = CompoundNBT( - [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], - "comp", - ) + comp1 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp") + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp") assert comp1 == comp2 comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2")], "comp") assert comp1 != comp2 - comp2 = CompoundNBT( - [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test4")], - "comp", - ) + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test4")], "comp") assert comp1 != comp2 - comp2 = CompoundNBT( - [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], - "comp2", - ) + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp2") assert comp1 != comp2 assert comp1 != ByteNBT(0, name="comp") # endregion -# region IntArrayNBT - - -@pytest.mark.parametrize( - ("payload", "error"), - [ - ([0, "test"], ValueError), - ([0, None], ValueError), - ([0, 1 << 31], OverflowError), - ([0, -(1 << 31) - 1], OverflowError), - ], -) -def test_serialize_intarray_fail(payload: PayloadType, error: type[Exception]): - """Test serialization of NBT INTARRAY tag with invalid value.""" - with pytest.raises(error): - IntArrayNBT(payload, "test").serialize() - - -def test_deserialize_intarray_fail(): - """Test deserialization of NBT INTARRAY tag with invalid value.""" - # Not enough data for 1 element - buffer = Buffer(bytearray([0x0B, 0, 0, 0, 1, 0, 0, 0])) - with pytest.raises(IOError): - IntArrayNBT.deserialize(buffer, with_name=False) - - # Not enough data for the size - buffer = Buffer(bytearray([0x0B, 0, 0, 0])) - with pytest.raises(IOError): - IntArrayNBT.read_from(buffer, with_name=False) - - # Not enough data to start the 2nd element - buffer = Buffer(bytearray([0x0B, 0, 0, 0, 2, 1, 0, 0, 0])) - with pytest.raises(IOError): - IntArrayNBT.read_from(buffer, with_name=False) - - -# endregion -# region LongArrayNBT - - -@pytest.mark.parametrize( - ("payload", "error"), - [ - ([0, "test"], ValueError), - ([0, None], ValueError), - ([0, 1 << 63], OverflowError), - ([0, -(1 << 63) - 1], OverflowError), - ], -) -def test_serialize_deserialize_longarray_fail(payload: PayloadType, error: type[Exception]): - """Test serialization/deserialization of NBT LONGARRAY tag with invalid value.""" - with pytest.raises(error): - LongArrayNBT(payload, "test").serialize() - -def test_deserialize_longarray_fail(): - """Test deserialization of NBT LONGARRAY tag with invalid value.""" - # Not enough data for 1 element - buffer = Buffer(bytearray([0x0C, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0])) - with pytest.raises(IOError): - LongArrayNBT.deserialize(buffer, with_name=False) +# region ListNBT - # Not enough data for the size - buffer = Buffer(bytearray([0x0C, 0, 0, 0])) - with pytest.raises(IOError): - LongArrayNBT.read_from(buffer, with_name=False) - # Not enough data to start the 2nd element - buffer = Buffer(bytearray([0x0C, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0])) - with pytest.raises(IOError): - LongArrayNBT.read_from(buffer, with_name=False) +def test_intarray_negative_length(): + """Test IntArray with negative length.""" + buffer = Buffer(b"\x0b\x00\x01a\xff\xff\xff\xff") + assert IntArrayNBT.read_from(buffer) == IntArrayNBT([], "a") # endregion @@ -1019,7 +421,7 @@ def test_nbt_helloworld(): def test_nbt_bigfile(): """Test serialization/deserialization of a big NBT tag. - Slightly modified from the source data to also include a IntArrayNBT and a LongArrayNBT. + Slighly modified from the source data to also include a IntArrayNBT and a LongArrayNBT. Source data: https://wiki.vg/NBT#Example. """ data = "0a00054c6576656c0400086c6f6e67546573747fffffffffffffff02000973686f7274546573747fff08000a737472696e6754657374002948454c4c4f20574f524c4420544849532049532041205445535420535452494e4720c385c384c39621050009666c6f6174546573743eff1832030007696e74546573747fffffff0a00146e657374656420636f6d706f756e6420746573740a000368616d0800046e616d65000648616d70757305000576616c75653f400000000a00036567670800046e616d6500074567676265727405000576616c75653f00000000000c000f6c6973745465737420286c6f6e672900000005000000000000000b000000000000000c000000000000000d000000000000000e7fffffffffffffff0b000e6c697374546573742028696e7429000000047fffffff7ffffffe7ffffffd7ffffffc0900136c697374546573742028636f6d706f756e64290a000000020800046e616d65000f436f6d706f756e642074616720233004000a637265617465642d6f6e000001265237d58d000800046e616d65000f436f6d706f756e642074616720233104000a637265617465642d6f6e000001265237d58d0001000862797465546573747f07006562797465417272617954657374202874686520666972737420313030302076616c756573206f6620286e2a6e2a3235352b6e2a3729253130302c207374617274696e672077697468206e3d302028302c2036322c2033342c2031362c20382c202e2e2e2929000003e8003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a063005000a646f75626c65546573743efc000000" # noqa: E501 @@ -1084,8 +486,8 @@ def check_equality(self: object, other: object) -> bool: if type(self) != type(other): return False if isinstance(self, dict): - self = cast("dict[Any, Any]", self) - other = cast("dict[Any, Any]", other) + self = cast(Dict[Any, Any], self) + other = cast(Dict[Any, Any], other) if len(self) != len(other): return False for key in self: @@ -1095,8 +497,8 @@ def check_equality(self: object, other: object) -> bool: return False return True if isinstance(self, list): - self = cast("list[Any]", self) - other = cast("list[Any]", other) + self = cast(List[Any], self) + other = cast(List[Any], other) if len(self) != len(other): return False return all(check_equality(self[i], other[i]) for i in range(len(self))) @@ -1112,8 +514,6 @@ def check_equality(self: object, other: object) -> bool: # endregion # region Edge cases - - def test_from_object_morecases(): """Test from_object with more edge cases.""" @@ -1199,7 +599,7 @@ def to_nbt(self, name: str = "") -> NBTag: ), ([1], [], ValueError, "The schema and the data must have the same length."), # Schema is a dict, data is not - (["test"], {"test": ByteNBT}, TypeError, "Expected a dictionary, but found list."), + (["test"], {"test": ByteNBT}, TypeError, "Expected a dictionary, but found a different type."), # Schema is not a dict, list or subclass of NBTagConvertible ( ["test"], @@ -1246,7 +646,7 @@ def to_nbt(self, name: str = "") -> NBTag: {"test": object()}, ByteNBT, TypeError, - r"Expected one of \(bytes, str, int, float, list\), but found object.", + "Expected one of \\(bytes, str, int, float, list\\), but found object.", ), # The data is a list but not all elements are ints ( @@ -1390,6 +790,3 @@ def test_wrong_type(buffer_content: str, tag_type: type[NBTag]): buffer = Buffer(bytearray.fromhex(buffer_content)) with pytest.raises(TypeError): tag_type.read_from(buffer, with_name=False) - - -# endregion diff --git a/tests/mcproto/types/test_uuid.py b/tests/mcproto/types/test_uuid.py index 956ba06e..ffdd7e7e 100644 --- a/tests/mcproto/types/test_uuid.py +++ b/tests/mcproto/types/test_uuid.py @@ -1,36 +1,20 @@ from __future__ import annotations -import pytest - -from mcproto.buffer import Buffer from mcproto.types.uuid import UUID +from tests.helpers import gen_serializable_test - -@pytest.mark.parametrize( - ("data", "expected_bytes"), - [ - ( - "12345678-1234-5678-1234-567812345678", - bytearray.fromhex("12345678123456781234567812345678"), - ), - ], -) -def test_serialize(data: str, expected_bytes: list[int]): - """Test serialization of UUID results in expected bytes.""" - output_bytes = UUID(data).serialize() - assert output_bytes == expected_bytes - - -@pytest.mark.parametrize( - ("input_bytes", "data"), - [ - ( - bytearray.fromhex("12345678123456781234567812345678"), - "12345678-1234-5678-1234-567812345678", - ), +gen_serializable_test( + context=globals(), + cls=UUID, + fields=[("hex", str)], + test_data=[ + (("12345678-1234-5678-1234-567812345678",), bytes.fromhex("12345678123456781234567812345678")), + # Too short or too long + (("12345678-1234-5678-1234-56781234567",), ValueError), + (("12345678-1234-5678-1234-5678123456789",), ValueError), + # Not enough data in the buffer (needs 16 bytes) + (IOError, b""), + (IOError, b"\x01"), + (IOError, b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e"), ], ) -def test_deserialize(input_bytes: list[int], data: str): - """Test deserialization of UUID with expected bytes produces expected UUID.""" - uuid = UUID.deserialize(Buffer(input_bytes)) - assert str(uuid) == data diff --git a/tests/mcproto/utils/test_serializable.py b/tests/mcproto/utils/test_serializable.py new file mode 100644 index 00000000..2b1b0fbd --- /dev/null +++ b/tests/mcproto/utils/test_serializable.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import final +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.utils.abc import Serializable, dataclass +from tests.helpers import gen_serializable_test + + +@final +@dataclass +class ToyClass(Serializable): + """Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`.""" + + a: int + b: str + + @override + def serialize_to(self, buf: Buffer): + """Write the object to a buffer.""" + buf.write_varint(self.a) + buf.write_utf(self.b) + + @classmethod + @override + def deserialize(cls, buf: Buffer) -> ToyClass: + """Deserialize the object from a buffer.""" + a = buf.read_varint() + if a == 0: + raise ZeroDivisionError("a must be non-zero") + b = buf.read_utf() + return cls(a, b) + + @override + def validate(self) -> None: + """Validate the object's attributes.""" + if self.a == 0: + raise ZeroDivisionError("a must be non-zero") + if len(self.b) > 10: + raise ValueError("b must be less than 10 characters") + + +gen_serializable_test( + context=globals(), + cls=ToyClass, + fields=[("a", int), ("b", str)], + test_data=[ + ((1, "hello"), b"\x01\x05hello"), + ((2, "world"), b"\x02\x05world"), + ((0, "hello"), ZeroDivisionError), + ((1, "hello world"), ValueError), + (ZeroDivisionError, b"\x00"), + (IOError, b"\x01"), + ], +)