From 1f26da5a03b622317e3c4a1231c1511bb03abc68 Mon Sep 17 00:00:00 2001 From: Alexis Rossfelder Date: Mon, 6 May 2024 20:27:44 +0200 Subject: [PATCH] Add more types * Angle * Bitset and FixedBitset * Position * Vec3 * Quaternion * Slot * Identifier * TextComponent - Rename ChatMessage to JSONTextComponent --- mcproto/packets/login/login.py | 6 +- mcproto/types/angle.py | 75 ++ mcproto/types/bitset.py | 193 ++++ mcproto/types/chat.py | 121 +- mcproto/types/identifier.py | 51 + mcproto/types/nbt.py | 354 +++--- mcproto/types/quaternion.py | 91 ++ mcproto/types/slot.py | 78 ++ mcproto/types/vec3.py | 164 +++ tests/mcproto/packets/login/test_login.py | 6 +- tests/mcproto/types/test_angle.py | 68 ++ tests/mcproto/types/test_bitset.py | 204 ++++ tests/mcproto/types/test_chat.py | 62 +- tests/mcproto/types/test_identifier.py | 19 + tests/mcproto/types/test_nbt.py | 1261 ++++++--------------- tests/mcproto/types/test_quaternion.py | 121 ++ tests/mcproto/types/test_slot.py | 32 + tests/mcproto/types/test_vec3.py | 214 ++++ 18 files changed, 1977 insertions(+), 1143 deletions(-) create mode 100644 mcproto/types/angle.py create mode 100644 mcproto/types/bitset.py create mode 100644 mcproto/types/identifier.py create mode 100644 mcproto/types/quaternion.py create mode 100644 mcproto/types/slot.py create mode 100644 mcproto/types/vec3.py create mode 100644 tests/mcproto/types/test_angle.py create mode 100644 tests/mcproto/types/test_bitset.py create mode 100644 tests/mcproto/types/test_identifier.py create mode 100644 tests/mcproto/types/test_quaternion.py create mode 100644 tests/mcproto/types/test_slot.py create mode 100644 tests/mcproto/types/test_vec3.py diff --git a/mcproto/packets/login/login.py b/mcproto/packets/login/login.py index bffe3eb7..21171cdf 100644 --- a/mcproto/packets/login/login.py +++ b/mcproto/packets/login/login.py @@ -9,7 +9,7 @@ from mcproto.buffer import Buffer from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket -from mcproto.types.chat import ChatMessage +from mcproto.types.chat import JSONTextComponent from mcproto.types.uuid import UUID from mcproto.utils.abc import dataclass @@ -177,7 +177,7 @@ class LoginDisconnect(ClientBoundPacket): PACKET_ID: ClassVar[int] = 0x00 GAME_STATE: ClassVar[GameState] = GameState.LOGIN - reason: ChatMessage + reason: JSONTextComponent @override def serialize_to(self, buf: Buffer) -> None: @@ -186,7 +186,7 @@ def serialize_to(self, buf: Buffer) -> None: @override @classmethod def _deserialize(cls, buf: Buffer, /) -> Self: - reason = ChatMessage.deserialize(buf) + reason = JSONTextComponent.deserialize(buf) return cls(reason) diff --git a/mcproto/types/angle.py b/mcproto/types/angle.py new file mode 100644 index 00000000..d1346f76 --- /dev/null +++ b/mcproto/types/angle.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import final +import math + +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.protocol import StructFormat +from mcproto.types.abc import MCType, dataclass +from mcproto.types.vec3 import Vec3 + + +@dataclass +@final +class Angle(MCType): + """Represents a rotation angle for an entity. + + :param value: The angle value in 1/256th of a full rotation. + """ + + angle: int + + @override + def serialize_to(self, buf: Buffer) -> None: + payload = int(self.angle) & 0xFF + # Convert to a signed byte. + if payload & 0x80: + payload -= 1 << 8 + buf.write_value(StructFormat.BYTE, payload) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Angle: + payload = buf.read_value(StructFormat.BYTE) + return cls(angle=int(payload * 360 / 256)) + + @override + def validate(self) -> None: + """Constrain the angle to the range [0, 256).""" + self.angle %= 256 + + def in_direction(self, base: Vec3, distance: float) -> Vec3: + """Calculate the position in the direction of the angle in the xz-plane. + + 0/256: Positive z-axis + 64/-192: Negative x-axis + 128/-128: Negative z-axis + 192/-64: Positive x-axis + + :param base: The base position. + :param distance: The distance to move. + :return: The new position. + """ + x = base.x - distance * math.sin(self.to_radians()) + z = base.z + distance * math.cos(self.to_radians()) + return Vec3(x=x, y=base.y, z=z) + + @classmethod + def from_degrees(cls, degrees: float) -> Angle: + """Create an angle from degrees.""" + return cls(angle=int(degrees * 256 / 360)) + + def to_degrees(self) -> float: + """Return the angle in degrees.""" + return self.angle * 360 / 256 + + @classmethod + def from_radians(cls, radians: float) -> Angle: + """Create an angle from radians.""" + return cls(angle=int(math.degrees(radians) * 256 / 360)) + + def to_radians(self) -> float: + """Return the angle in radians.""" + return math.radians(self.angle * 360 / 256) diff --git a/mcproto/types/bitset.py b/mcproto/types/bitset.py new file mode 100644 index 00000000..2b0927d5 --- /dev/null +++ b/mcproto/types/bitset.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import math + +from typing import ClassVar +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.protocol import StructFormat +from mcproto.types.abc import MCType, dataclass + + +@dataclass +class FixedBitset(MCType): + """Represents a fixed-size bitset.""" + + __n: ClassVar[int] = -1 + + data: bytearray + + @override + def serialize_to(self, buf: Buffer) -> None: + buf.write(bytes(self.data)) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> FixedBitset: + data = buf.read(math.ceil(cls.__n / 8)) + return cls(data=data) + + @override + def validate(self) -> None: + """Validate the bitset.""" + if self.__n == -1: + raise ValueError("Bitset size is not defined.") + if len(self.data) != math.ceil(self.__n / 8): + raise ValueError(f"Bitset size is {len(self.data) * 8}, expected {self.__n}.") + + @staticmethod + def of_size(n: int) -> type[FixedBitset]: + """Return a new FixedBitset class with the given size. + + :param n: The size of the bitset. + """ + new_class = type(f"FixedBitset{n}", (FixedBitset,), {}) + new_class.__n = n + return new_class + + @classmethod + def from_int(cls, n: int) -> FixedBitset: + """Return a new FixedBitset with the given integer value. + + :param n: The integer value. + """ + if cls.__n == -1: + raise ValueError("Bitset size is not defined.") + if n < 0: + # Manually compute two's complement + n = -n + data = bytearray(n.to_bytes(math.ceil(cls.__n / 8), "big")) + for i in range(len(data)): + data[i] ^= 0xFF + data[-1] += 1 + else: + data = bytearray(n.to_bytes(math.ceil(cls.__n / 8), "big")) + return cls(data=data) + + def __setitem__(self, index: int, value: bool) -> None: + byte_index = index // 8 + bit_index = index % 8 + if value: + self.data[byte_index] |= 1 << bit_index + else: + self.data[byte_index] &= ~(1 << bit_index) + + def __getitem__(self, index: int) -> bool: + byte_index = index // 8 + bit_index = index % 8 + return bool(self.data[byte_index] & (1 << bit_index)) + + def __len__(self) -> int: + return self.__n + + def __and__(self, other: FixedBitset) -> FixedBitset: + if self.__n != other.__n: + raise ValueError("Bitsets must have the same size.") + return type(self)(data=bytearray(a & b for a, b in zip(self.data, other.data))) + + def __or__(self, other: FixedBitset) -> FixedBitset: + if self.__n != other.__n: + raise ValueError("Bitsets must have the same size.") + return type(self)(data=bytearray(a | b for a, b in zip(self.data, other.data))) + + def __xor__(self, other: FixedBitset) -> FixedBitset: + if self.__n != other.__n: + raise ValueError("Bitsets must have the same size.") + return type(self)(data=bytearray(a ^ b for a, b in zip(self.data, other.data))) + + def __invert__(self) -> FixedBitset: + return type(self)(data=bytearray(~a & 0xFF for a in self.data)) + + def __bytes__(self) -> bytes: + return bytes(self.data) + + @override + def __eq__(self, value: object) -> bool: + if not isinstance(value, FixedBitset): + return NotImplemented + return self.data == value.data and self.__n == value.__n + + +@dataclass +class Bitset(MCType): + """Represents a lenght-prefixed bitset with a variable size. + + :param size: The number of longs in the array representing the bitset. + :param data: The bits of the bitset. + """ + + size: int + data: list[int] + + @override + def serialize_to(self, buf: Buffer) -> None: + buf.write_varint(self.size) + for i in range(self.size): + buf.write_value(StructFormat.LONGLONG, self.data[i]) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Bitset: + size = buf.read_varint() + if buf.remaining < size * 8: + raise IOError("Not enough data to read bitset.") + data = [buf.read_value(StructFormat.LONGLONG) for _ in range(size)] + return cls(size=size, data=data) + + @override + def validate(self) -> None: + """Validate the bitset.""" + if self.size != len(self.data): + raise ValueError(f"Bitset size is {self.size}, expected {len(self.data)}.") + + @classmethod + def from_int(cls, n: int, size: int | None = None) -> Bitset: + """Return a new Bitset with the given integer value. + + :param n: The integer value. + :param size: The number of longs in the array representing the bitset. + """ + if size is None: + size = math.ceil(float(n.bit_length()) / 64.0) + data = [n >> (i * 64) & 0xFFFFFFFFFFFFFFFF for i in range(size)] + return cls(size=size, data=data) + + def __getitem__(self, index: int) -> bool: + byte_index = index // 64 + bit_index = index % 64 + + return bool(self.data[byte_index] & (1 << bit_index)) + + def __setitem__(self, index: int, value: bool) -> None: + byte_index = index // 64 + bit_index = index % 64 + + if value: + self.data[byte_index] |= 1 << bit_index + else: + self.data[byte_index] &= ~(1 << bit_index) + + def __len__(self) -> int: + return self.size * 64 + + def __and__(self, other: Bitset) -> Bitset: + if self.size != other.size: + raise ValueError("Bitsets must have the same size.") + return Bitset(size=self.size, data=[a & b for a, b in zip(self.data, other.data)]) + + def __or__(self, other: Bitset) -> Bitset: + if self.size != other.size: + raise ValueError("Bitsets must have the same size.") + return Bitset(size=self.size, data=[a | b for a, b in zip(self.data, other.data)]) + + def __xor__(self, other: Bitset) -> Bitset: + if self.size != other.size: + raise ValueError("Bitsets must have the same size.") + return Bitset(size=self.size, data=[a ^ b for a, b in zip(self.data, other.data)]) + + def __invert__(self) -> Bitset: + return Bitset(size=self.size, data=[~a for a in self.data]) + + def __bytes__(self) -> bytes: + return b"".join(a.to_bytes(8, "big") for a in self.data) diff --git a/mcproto/types/chat.py b/mcproto/types/chat.py index fe978631..e674bb53 100644 --- a/mcproto/types/chat.py +++ b/mcproto/types/chat.py @@ -1,26 +1,27 @@ from __future__ import annotations import json -from typing import TypedDict, Union, final +from typing import Tuple, TypedDict, Union, cast, final from typing_extensions import Self, TypeAlias, override from mcproto.buffer import Buffer from mcproto.types.abc import MCType, dataclass +from mcproto.types.nbt import NBTag, StringNBT, ByteNBT, FromObjectSchema, FromObjectType __all__ = [ - "ChatMessage", - "RawChatMessage", - "RawChatMessageDict", + "TextComponent", + "RawTextComponentDict", + "RawTextComponent", ] -class RawChatMessageDict(TypedDict, total=False): +class RawTextComponentDict(TypedDict, total=False): """Dictionary structure of JSON chat messages when serialized.""" text: str translation: str - extra: list[RawChatMessageDict] + extra: list[RawTextComponentDict] color: str bold: bool @@ -30,22 +31,27 @@ class RawChatMessageDict(TypedDict, total=False): obfuscated: bool -RawChatMessage: TypeAlias = Union[RawChatMessageDict, "list[RawChatMessageDict]", str] +RawTextComponent: TypeAlias = Union[RawTextComponentDict, "list[RawTextComponentDict]", str] + + +def _deep_copy_dict(data: RawTextComponentDict) -> RawTextComponentDict: + """Deep copy a dictionary structure.""" + json_data = json.dumps(data) + return json.loads(json_data) @dataclass -@final -class ChatMessage(MCType): +class JSONTextComponent(MCType): """Minecraft chat message representation.""" - raw: RawChatMessage + raw: RawTextComponent - def as_dict(self) -> RawChatMessageDict: + def as_dict(self) -> RawTextComponentDict: """Convert received ``raw`` into a stadard :class:`dict` form.""" if isinstance(self.raw, list): - return RawChatMessageDict(extra=self.raw) + return RawTextComponentDict(extra=self.raw) if isinstance(self.raw, str): - return RawChatMessageDict(text=self.raw) + return RawTextComponentDict(text=self.raw) if isinstance(self.raw, dict): # pyright: ignore[reportUnnecessaryIsInstance] return self.raw @@ -61,7 +67,7 @@ def __eq__(self, other: object) -> bool: a chat message that appears the same, but was representing in a different way will fail this equality check. """ - if not isinstance(other, ChatMessage): + if not isinstance(other, JSONTextComponent): return NotImplemented return self.raw == other.raw @@ -93,3 +99,90 @@ def validate(self) -> None: raise AttributeError( "Expected each element in `raw` to have either 'text' or 'extra' key, got neither" ) + + +@final +class TextComponent(JSONTextComponent): + """Minecraft chat message representation. + + This class provides the new chat message format using NBT data instead of JSON. + """ + + __slots__ = () + + @override + def serialize_to(self, buf: Buffer) -> None: + payload = self._convert_to_dict(self.raw) + payload = cast(FromObjectType, payload) # We just ensured that the data is converted to the correct format + nbt = NBTag.from_object(data=payload, schema=self._build_schema()) # This will validate the data + nbt.serialize_to(buf) + + @override + @classmethod + def deserialize(cls, buf: Buffer, /) -> Self: + nbt = NBTag.deserialize(buf, with_name=False) + # Ensure the schema is compatible with the one defined in the class + data, schema = cast(Tuple[FromObjectType, FromObjectSchema], nbt.to_object(include_schema=True)) + + def recursive_validate(recieved: FromObjectSchema, expected: FromObjectSchema) -> None: + if isinstance(recieved, dict): + if not isinstance(expected, dict): + raise TypeError(f"Expected {expected!r}, got dict") + for key, value in recieved.items(): + if key not in expected: + raise KeyError(f"Unexpected key {key!r}") + recursive_validate(value, expected[key]) + elif isinstance(recieved, list): + if not isinstance(expected, list): + raise TypeError(f"Expected {expected!r}, got list") + for rec in recieved: + recursive_validate(rec, expected[0]) + elif recieved != expected: + raise TypeError(f"Expected {expected!r}, got {recieved!r}") + + recursive_validate(schema, cls._build_schema()) + data = cast(RawTextComponentDict, data) # We just ensured that the data is compatible with the schema + return cls(data) + + @staticmethod + def _build_schema() -> FromObjectSchema: + """Build the schema for the NBT data representing the chat message.""" + schema: FromObjectSchema = { + "text": StringNBT, + "color": StringNBT, + "bold": ByteNBT, + "italic": ByteNBT, + "underlined": ByteNBT, + "strikethrough": ByteNBT, + "obfuscated": ByteNBT, + } + # Allow the schema to be recursive + schema["extra"] = [schema] # type: ignore + return schema + + @staticmethod + def _convert_to_dict(msg: RawTextComponent) -> RawTextComponentDict: + """Convert a chat message into a dictionary representation.""" + if isinstance(msg, str): + return {"text": msg} + + if isinstance(msg, list): + main = TextComponent._convert_to_dict(msg[0]) + if "extra" not in main: + main["extra"] = [] + for elem in msg[1:]: + main["extra"].append(TextComponent._convert_to_dict(elem)) + return main + + if isinstance(msg, dict): # pyright: ignore[reportUnnecessaryIsInstance] + return _deep_copy_dict(msg) # We don't want to modify self.raw for example + + raise TypeError(f"Unexpected type {msg!r} ({msg.__class__.__name__})") # pragma: no cover + + @override + def __eq__(self, other: object) -> bool: + if not isinstance(other, TextComponent): + return NotImplemented + self_dict = self._convert_to_dict(self.raw) + other_dict = self._convert_to_dict(other.raw) + return self_dict == other_dict diff --git a/mcproto/types/identifier.py b/mcproto/types/identifier.py new file mode 100644 index 00000000..cdd8deca --- /dev/null +++ b/mcproto/types/identifier.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import re +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.types.abc import MCType, dataclass + + +@dataclass +class Identifier(MCType): + """A Minecraft identifier. + + :param namespace: The namespace of the identifier. + :param path: The path of the identifier. + """ + + namespace: str + path: str + + @override + def serialize_to(self, buf: Buffer) -> None: + data = f"{self.namespace}:{self.path}" + buf.write_utf(data) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Identifier: + data = buf.read_utf() + namespace, path = data.split(":", 1) + return cls(namespace, path) + + @override + def validate(self) -> None: + if len(self.namespace) == 0: + raise ValueError("Namespace cannot be empty.") + + if len(self.path) == 0: + raise ValueError("Path cannot be empty.") + + if len(self.namespace) + len(self.path) + 1 > 32767: + raise ValueError("Identifier is too long.") + + namespace_regex = r"^[a-z0-9-_]+$" + path_regex = r"^[a-z0-9-_/]+$" + + if not re.match(namespace_regex, self.namespace): + raise ValueError(f"Namespace must match regex {namespace_regex}, got {self.namespace!r}") + + if not re.match(path_regex, self.path): + raise ValueError(f"Path must match regex {path_regex}, got {self.path!r}") diff --git a/mcproto/types/nbt.py b/mcproto/types/nbt.py index 6a230a04..4fc0a04b 100644 --- a/mcproto/types/nbt.py +++ b/mcproto/types/nbt.py @@ -2,14 +2,14 @@ from abc import abstractmethod from enum import IntEnum -from typing import Union, cast, Protocol, runtime_checkable +from typing import ClassVar, Union, cast, Protocol, final, runtime_checkable from collections.abc import Iterator, Mapping, Sequence from typing_extensions import TypeAlias, override from mcproto.buffer import Buffer -from mcproto.protocol.base_io import StructFormat -from mcproto.types.abc import MCType +from mcproto.protocol.base_io import StructFormat, INT_FORMATS_TYPE, FLOAT_FORMATS_TYPE +from mcproto.types.abc import MCType, dataclass __all__ = [ "ByteArrayNBT", @@ -187,10 +187,6 @@ class NBTag(MCType, NBTagConvertible): __slots__ = ("name", "payload") - def __init__(self, payload: PayloadType, name: str = ""): - self.name = name - self.payload = payload - @override def serialize(self, with_type: bool = True, with_name: bool = True) -> Buffer: """Serialize the NBT tag to a new buffer. @@ -359,7 +355,7 @@ def from_object(data: FromObjectType, schema: FromObjectSchema, name: str = "") raise TypeError("Expected a list of integers, but a non-integer element was found.") data = cast(Union[bytes, str, int, float, "list[int]"], data) # Create the tag with the data and the name - return schema(data, name=name) + return schema(data, name=name) # type: ignore # The schema is a subclass of NBTag # Sanity check : Verify that all type schemas have been handled if not isinstance(schema, (list, tuple, dict)): @@ -371,7 +367,7 @@ def from_object(data: FromObjectType, schema: FromObjectSchema, name: str = "") if isinstance(schema, dict): # We can unpack the dictionary and create a CompoundNBT tag if not isinstance(data, dict): - raise TypeError(f"Expected a dictionary, but found {type(data).__name__}.") + raise TypeError(f"Expected a dictionary, but found a different type ({type(data).__name__}).") # Iterate over the dictionary payload: list[NBTag] = [] for key, value in data.items(): @@ -480,14 +476,13 @@ def value(self) -> PayloadType: # region NBT tags types +@final +@dataclass class EndNBT(NBTag): """Sentinel tag used to mark the end of a TAG_Compound.""" - __slots__ = () - - def __init__(self): - """Create a new EndNBT tag.""" - super().__init__(0, name="") + payload: None = None + name: str = "" @override def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = False) -> None: @@ -513,151 +508,118 @@ def value(self) -> PayloadType: return NotImplemented -class ByteNBT(NBTag): - """NBT tag representing a single byte value, represented as a signed 8-bit integer.""" +@dataclass +class _NumberNBTag(NBTag): + """Base class for NBT tags representing a number. + + This class is not meant to be used directly, but rather through its subclasses. + """ + + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = NotImplemented # type: ignore + DATA_SIZE: ClassVar[int] = NotImplemented - __slots__ = () payload: int + name: str = "" @override def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: self._write_header(buf, with_type=with_type, with_name=with_name) - if self.payload < -(1 << 7) or self.payload >= 1 << 7: - raise OverflowError("Byte value out of range.") - - buf.write_value(StructFormat.BYTE, self.payload) + buf.write_value(self.STRUCT_FORMAT, self.payload) @override @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteNBT: + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> _NumberNBTag: name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") - if buf.remaining < 1: - raise IOError("Buffer does not contain enough data to read a byte. (Empty buffer)") + if buf.remaining < cls.DATA_SIZE: + raise IOError(f"Buffer does not contain enough data to read a {tag_type.name}.") - return ByteNBT(buf.read_value(StructFormat.BYTE), name=name) + return cls(buf.read_value(cls.STRUCT_FORMAT), name=name) - def __int__(self) -> int: - """Get the integer value of the ByteNBT tag.""" - return self.payload + @override + def validate(self) -> None: + if not isinstance(self.payload, int): # type: ignore + raise TypeError(f"Expected an int, but found {type(self.payload).__name__}.") + int_min = -(1 << (self.DATA_SIZE * 8 - 1)) + int_max = (1 << (self.DATA_SIZE * 8 - 1)) - 1 + if not int_min <= self.payload <= int_max: + raise OverflowError(f"Value out of range for a {type(self).__name__} tag.") @property @override def value(self) -> int: return self.payload + def __int__(self) -> int: + return self.payload -class ShortNBT(ByteNBT): - """NBT tag representing a short value, represented as a signed 16-bit integer.""" - - __slots__ = () - - @override - def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - self._write_header(buf, with_type=with_type, with_name=with_name) - - if self.payload < -(1 << 15) or self.payload >= 1 << 15: - raise OverflowError("Short value out of range.") - - buf.write_value(StructFormat.SHORT, self.payload) - - @override - @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ShortNBT: - name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if _get_tag_type(cls) != tag_type: - raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") - - if buf.remaining < 2: - raise IOError("Buffer does not contain enough data to read a short.") - - return ShortNBT(buf.read_value(StructFormat.SHORT), name=name) +class ByteNBT(_NumberNBTag): + """NBT tag representing a single byte value, represented as a signed 8-bit integer.""" -class IntNBT(ByteNBT): - """NBT tag representing an integer value, represented as a signed 32-bit integer.""" + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = StructFormat.BYTE + DATA_SIZE: ClassVar[int] = 1 __slots__ = () - @override - def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - self._write_header(buf, with_type=with_type, with_name=with_name) - - if self.payload < -(1 << 31) or self.payload >= 1 << 31: - raise OverflowError("Integer value out of range.") - # No more messing around with the struct, we want 32 bits of data no matter what - buf.write_value(StructFormat.INT, self.payload) +class ShortNBT(_NumberNBTag): + """NBT tag representing a short value, represented as a signed 16-bit integer.""" - @override - @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntNBT: - name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if _get_tag_type(cls) != tag_type: - raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = StructFormat.SHORT + DATA_SIZE: ClassVar[int] = 2 - if buf.remaining < 4: - raise IOError("Buffer does not contain enough data to read an int.") + __slots__ = () - return IntNBT(buf.read_value(StructFormat.INT), name=name) +class IntNBT(_NumberNBTag): + """NBT tag representing an integer value, represented as a signed 32-bit integer.""" -class LongNBT(ByteNBT): - """NBT tag representing a long value, represented as a signed 64-bit integer.""" + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = StructFormat.INT + DATA_SIZE: ClassVar[int] = 4 __slots__ = () - @override - def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - self._write_header(buf, with_type=with_type, with_name=with_name) - if self.payload < -(1 << 63) or self.payload >= 1 << 63: - raise OverflowError("Long value out of range.") +class LongNBT(_NumberNBTag): + """NBT tag representing a long value, represented as a signed 64-bit integer.""" - # No more messing around with the struct, we want 64 bits of data no matter what - buf.write_value(StructFormat.LONGLONG, self.payload) + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = StructFormat.LONGLONG + DATA_SIZE: ClassVar[int] = 8 - @override - @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongNBT: - name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if _get_tag_type(cls) != tag_type: - raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + __slots__ = () - if buf.remaining < 8: - raise IOError("Buffer does not contain enough data to read a long.") - return LongNBT(buf.read_value(StructFormat.LONGLONG), name=name) +@dataclass +class _FloatingNBTag(NBTag): + """Base class for NBT tags representing a floating-point number.""" - -class FloatNBT(NBTag): - """NBT tag representing a floating-point value, represented as a 32-bit IEEE 754-2008 binary32 value.""" + STRUCT_FORMAT: ClassVar[FLOAT_FORMATS_TYPE] = NotImplemented # type: ignore + DATA_SIZE: ClassVar[int] = NotImplemented payload: float - - __slots__ = () + name: str = "" @override def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: self._write_header(buf, with_type=with_type, with_name=with_name) - buf.write_value(StructFormat.FLOAT, self.payload) + buf.write_value(self.STRUCT_FORMAT, self.payload) @override @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> FloatNBT: + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> _FloatingNBTag: name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") - if buf.remaining < 4: - raise IOError("Buffer does not contain enough data to read a float.") + if buf.remaining < cls.DATA_SIZE: + raise IOError(f"Buffer does not contain enough data to read a {tag_type.name}.") - return FloatNBT(buf.read_value(StructFormat.FLOAT), name=name) + return cls(buf.read_value(cls.STRUCT_FORMAT), name=name) def __float__(self) -> float: - """Get the float value of the FloatNBT tag.""" return self.payload @property @@ -665,36 +627,40 @@ def __float__(self) -> float: def value(self) -> float: return self.payload + @override + def validate(self) -> None: + if isinstance(self.payload, int): + self.payload = float(self.payload) + if not isinstance(self.payload, float): + raise TypeError(f"Expected a float, but found {type(self.payload).__name__}.") -class DoubleNBT(FloatNBT): - """NBT tag representing a double-precision floating-point value, represented as a 64-bit IEEE 754-2008 binary64.""" + +@final +class FloatNBT(_FloatingNBTag): + """NBT tag representing a floating-point value, represented as a 32-bit IEEE 754-2008 binary32 value.""" + + STRUCT_FORMAT: ClassVar[FLOAT_FORMATS_TYPE] = StructFormat.FLOAT + DATA_SIZE: ClassVar[int] = 4 __slots__ = () - @override - def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - self._write_header(buf, with_type=with_type, with_name=with_name) - buf.write_value(StructFormat.DOUBLE, self.payload) - @override - @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> DoubleNBT: - name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if _get_tag_type(cls) != tag_type: - raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") +@final +class DoubleNBT(_FloatingNBTag): + """NBT tag representing a double-precision floating-point value, represented as a 64-bit IEEE 754-2008 binary64.""" - if buf.remaining < 8: - raise IOError("Buffer does not contain enough data to read a double.") + STRUCT_FORMAT: ClassVar[FLOAT_FORMATS_TYPE] = StructFormat.DOUBLE + DATA_SIZE: ClassVar[int] = 8 - return DoubleNBT(buf.read_value(StructFormat.DOUBLE), name=name) + __slots__ = () +@dataclass class ByteArrayNBT(NBTag): """NBT tag representing an array of bytes. The length of the array is stored as a signed 32-bit integer.""" - __slots__ = () - payload: bytes + name: str = "" @override def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: @@ -740,13 +706,20 @@ def __repr__(self) -> str: def value(self) -> bytes: return self.payload + @override + def validate(self) -> None: + if isinstance(self.payload, bytearray): + self.payload = bytes(self.payload) + if not isinstance(self.payload, bytes): + raise TypeError(f"Expected a bytes, but found {type(self.payload).__name__}.") + +@dataclass class StringNBT(NBTag): """NBT tag representing an UTF-8 string value. The length of the string is stored as a signed 16-bit integer.""" - __slots__ = () - payload: str + name: str = "" @override def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: @@ -787,13 +760,25 @@ def __str__(self) -> str: def value(self) -> str: return self.payload + @override + def validate(self) -> None: + if not isinstance(self.payload, str): # type: ignore + raise TypeError(f"Expected a str, but found {type(self.payload).__name__}.") + if len(self.payload) > 32767: + raise ValueError("Maximum character limit for writing strings is 32767 characters.") + # Check that the string is valid UTF-8 + try: + self.payload.encode("utf-8") + except UnicodeEncodeError as exc: # pragma: no cover (don't know how to trigger) + raise ValueError("Invalid UTF-8 string.") from exc + +@dataclass class ListNBT(NBTag): """NBT tag representing a list of tags. All tags in the list must be of the same type.""" - __slots__ = () - payload: list[NBTag] + name: str = "" @override def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: @@ -902,13 +887,27 @@ def to_object( def value(self) -> list[PayloadType]: return [tag.value for tag in self.payload] + @override + def validate(self) -> None: + if not isinstance(self.payload, list): # type: ignore + raise TypeError(f"Expected a list, but found {type(self.payload).__name__}.") + if not all(isinstance(tag, NBTag) for tag in self.payload): # type: ignore # We want to check anyway + raise TypeError("All items in a list must be NBTags.") + if not self.payload: + return + first_tag_type = type(self.payload[0]) + if not all(type(tag) is first_tag_type for tag in self.payload): + raise TypeError("All tags in a list must be of the same type.") + if not all(tag.name == "" for tag in self.payload): + raise ValueError("All tags in a list must be unnamed.") + +@dataclass class CompoundNBT(NBTag): """NBT tag representing a compound of named tags.""" - __slots__ = () - payload: list[NBTag] + name: str = "" @override def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: @@ -1014,95 +1013,86 @@ def __eq__(self, other: object) -> bool: def value(self) -> dict[str, PayloadType]: return {tag.name: tag.value for tag in self.payload} + @override + def validate(self) -> None: + if not isinstance(self.payload, list): # type: ignore + raise TypeError(f"Expected a list, but found {type(self.payload).__name__}.") + if not all(isinstance(tag, NBTag) for tag in self.payload): # type: ignore + raise TypeError("All items in a compound must be NBTags.") + if not all(tag.name for tag in self.payload): + raise ValueError("All tags in a compound must be named.") + if len(self.payload) != len({tag.name for tag in self.payload}): + raise ValueError("All tags in a compound must have unique names.") -class IntArrayNBT(NBTag): - """NBT tag representing an array of integers. The length of the array is stored as a signed 32-bit integer.""" - __slots__ = () +@dataclass +class _NumberArrayNBTag(NBTag): + """Base class for NBT tags representing an array of numbers.""" + + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = NotImplemented # type: ignore + DATA_SIZE: ClassVar[int] = NotImplemented payload: list[int] + name: str = "" @override def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: self._write_header(buf, with_type=with_type, with_name=with_name) - - if any(not isinstance(item, int) for item in self.payload): # type: ignore # We want to check anyway - raise ValueError("All items in an integer array must be integers.") - - if any(item < -(1 << 31) or item >= 1 << 31 for item in self.payload): - raise OverflowError("Integer array contains values out of range.") - IntNBT(len(self.payload)).serialize_to(buf, with_name=False, with_type=False) for i in self.payload: - IntNBT(i).serialize_to(buf, with_name=False, with_type=False) + buf.write_value(self.STRUCT_FORMAT, i) @override @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntArrayNBT: + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> _NumberArrayNBTag: name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != NBTagType.INT_ARRAY: - raise TypeError(f"Expected an INT_ARRAY tag, but found a different tag ({tag_type}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") length = IntNBT.read_from(buf, with_type=False, with_name=False).value - try: - payload = [IntNBT.read_from(buf, with_type is NBTagType.INT, with_name=False).value for _ in range(length)] - except IOError as exc: - raise IOError( - "Buffer does not contain enough data to read the entire integer array. (Incomplete data)" - ) from exc - return IntArrayNBT(payload, name=name) - @override - def __repr__(self) -> str: - if self.name: - return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" - if len(self.payload) < 8: - return f"{type(self).__name__}(length={len(self.payload)}, {self.payload!r})" - return f"{type(self).__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" + if buf.remaining < length * cls.DATA_SIZE: + raise IOError(f"Buffer does not contain enough data to read the entire {tag_type.name}.") - def __iter__(self) -> Iterator[int]: - """Iterate over the integers in the array.""" - yield from self.payload + return cls([buf.read_value(cls.STRUCT_FORMAT) for _ in range(length)], name=name) + + @override + def validate(self) -> None: + if not isinstance(self.payload, list): # type: ignore + raise TypeError(f"Expected a list, but found {type(self.payload).__name__}.") + if not all(isinstance(item, int) for item in self.payload): # type: ignore + raise TypeError("All items in an integer array must be integers.") + if any( + item < -(1 << (self.DATA_SIZE * 8 - 1)) or item >= 1 << (self.DATA_SIZE * 8 - 1) for item in self.payload + ): + raise OverflowError(f"Integer array contains values out of range. ({self.payload})") @property @override def value(self) -> list[int]: return self.payload + def __iter__(self) -> Iterator[int]: + yield from self.payload -class LongArrayNBT(IntArrayNBT): - """NBT tag representing an array of longs. The length of the array is stored as a signed 32-bit integer.""" - __slots__ = () +@final +class IntArrayNBT(_NumberArrayNBTag): + """NBT tag representing an array of integers. The length of the array is stored as a signed 32-bit integer.""" - @override - def serialize_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - self._write_header(buf, with_type=with_type, with_name=with_name) + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = StructFormat.INT + DATA_SIZE: ClassVar[int] = 4 - if any(not isinstance(item, int) for item in self.payload): # type: ignore # We want to check anyway - raise ValueError(f"All items in a long array must be integers. ({self.payload})") + __slots__ = () - if any(item < -(1 << 63) or item >= 1 << 63 for item in self.payload): - raise OverflowError(f"Long array contains values out of range. ({self.payload})") - IntNBT(len(self.payload)).serialize_to(buf, with_name=False, with_type=False) - for i in self.payload: - LongNBT(i).serialize_to(buf, with_name=False, with_type=False) +@final +class LongArrayNBT(_NumberArrayNBTag): + """NBT tag representing an array of longs. The length of the array is stored as a signed 32-bit integer.""" - @override - @classmethod - def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongArrayNBT: - name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != NBTagType.LONG_ARRAY: - raise TypeError(f"Expected a LONG_ARRAY tag, but found a different tag ({tag_type}).") - length = IntNBT.read_from(buf, with_type=False, with_name=False).payload + STRUCT_FORMAT: ClassVar[INT_FORMATS_TYPE] = StructFormat.LONGLONG + DATA_SIZE: ClassVar[int] = 8 - try: - payload = [LongNBT.read_from(buf, with_type=False, with_name=False).payload for _ in range(length)] - except IOError as exc: - raise IOError( - "Buffer does not contain enough data to read the entire long array. (Incomplete data)" - ) from exc - return LongArrayNBT(payload, name=name) + __slots__ = () # endregion diff --git a/mcproto/types/quaternion.py b/mcproto/types/quaternion.py new file mode 100644 index 00000000..c2f85ae7 --- /dev/null +++ b/mcproto/types/quaternion.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import math + +from typing import final +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.protocol import StructFormat +from mcproto.types.abc import MCType, dataclass + + +@dataclass +@final +class Quaternion(MCType): + """Represents a quaternion. + + :param x: The x component. + :param y: The y component. + :param z: The z component. + :param w: The w component. + """ + + x: float | int + y: float | int + z: float | int + w: float | int + + @override + def serialize_to(self, buf: Buffer) -> None: + buf.write_value(StructFormat.FLOAT, self.x) + buf.write_value(StructFormat.FLOAT, self.y) + buf.write_value(StructFormat.FLOAT, self.z) + buf.write_value(StructFormat.FLOAT, self.w) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Quaternion: + x = buf.read_value(StructFormat.FLOAT) + y = buf.read_value(StructFormat.FLOAT) + z = buf.read_value(StructFormat.FLOAT) + w = buf.read_value(StructFormat.FLOAT) + return cls(x=x, y=y, z=z, w=w) + + @override + def validate(self) -> None: + """Validate the quaternion's components.""" + # Check that the components are floats or integers. + if not all(isinstance(comp, (float, int)) for comp in (self.x, self.y, self.z, self.w)): # type: ignore + raise TypeError( + f"Quaternion components must be floats or integers, got {self.x!r}, {self.y!r}, {self.z!r}, {self.w!r}" + ) + + # Check that the components are not NaN. + if any(not math.isfinite(comp) for comp in (self.x, self.y, self.z, self.w)): + raise ValueError( + f"Quaternion components must not be NaN, got {self.x!r}, {self.y!r}, {self.z!r}, {self.w!r}." + ) + + def __add__(self, other: Quaternion) -> Quaternion: + # Use the type of self to return a Quaternion or a subclass. + return type(self)(x=self.x + other.x, y=self.y + other.y, z=self.z + other.z, w=self.w + other.w) + + def __sub__(self, other: Quaternion) -> Quaternion: + return type(self)(x=self.x - other.x, y=self.y - other.y, z=self.z - other.z, w=self.w - other.w) + + def __neg__(self) -> Quaternion: + return type(self)(x=-self.x, y=-self.y, z=-self.z, w=-self.w) + + def __mul__(self, other: float) -> Quaternion: + return type(self)(x=self.x * other, y=self.y * other, z=self.z * other, w=self.w * other) + + def __truediv__(self, other: float) -> Quaternion: + return type(self)(x=self.x / other, y=self.y / other, z=self.z / other, w=self.w / other) + + def to_tuple(self) -> tuple[float, float, float, float]: + """Convert the quaternion to a tuple.""" + return (self.x, self.y, self.z, self.w) + + def norm_squared(self) -> float: + """Return the squared norm of the quaternion.""" + return self.x**2 + self.y**2 + self.z**2 + self.w**2 + + def norm(self) -> float: + """Return the norm of the quaternion.""" + return math.sqrt(self.norm_squared()) + + def normalize(self) -> Quaternion: + """Return the normalized quaternion.""" + norm = self.norm() + return Quaternion(x=self.x / norm, y=self.y / norm, z=self.z / norm, w=self.w / norm) diff --git a/mcproto/types/slot.py b/mcproto/types/slot.py new file mode 100644 index 00000000..d8e21d48 --- /dev/null +++ b/mcproto/types/slot.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import cast, final + +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.protocol import StructFormat +from mcproto.types.nbt import CompoundNBT, EndNBT, NBTag +from mcproto.types.abc import MCType, dataclass + +__all__ = ["Slot"] + +""" +https://wiki.vg/Slot_Data +""" + + +@dataclass +@final +class Slot(MCType): + """Represents a slot in an inventory. + + :param present: Whether the slot has an item in it. + :param item_id: (optional) The item ID of the item in the slot. + :param item_count: (optional) The count of items in the slot. + :param nbt: (optional) The NBT data of the item in the slot. + + The optional parameters are present if and only if the slot is present. + """ + + present: bool + item_id: int | None = None + item_count: int | None = None + nbt: NBTag | None = None + + @override + def serialize_to(self, buf: Buffer) -> None: + buf.write_value(StructFormat.BOOL, self.present) + if self.present: + self.item_id = cast(int, self.item_id) + self.item_count = cast(int, self.item_count) + self.nbt = cast(NBTag, self.nbt) + buf.write_varint(self.item_id) + buf.write_value(StructFormat.BYTE, self.item_count) + self.nbt.serialize_to(buf, with_name=False) # In 1.20.2 and later, the NBT is not named, there is only the + # type (TAG_End or TAG_Compound) and the payload. + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Slot: + present = buf.read_value(StructFormat.BOOL) + if not present: + return cls(present=False) + item_id = buf.read_varint() + item_count = buf.read_value(StructFormat.BYTE) + nbt = NBTag.deserialize(buf, with_name=False) + return cls(present=True, item_id=item_id, item_count=item_count, nbt=nbt) + + @override + def validate(self) -> None: + # If the slot is present, all the fields must be present. + if self.present: + if self.item_id is None: + raise ValueError("Item ID is missing.") + if self.item_count is None: + raise ValueError("Item count is missing.") + if self.nbt is None: + self.nbt = EndNBT() + elif not isinstance(self.nbt, (CompoundNBT, EndNBT)): + raise TypeError("NBT data associated with a slot must be in a CompoundNBT.") + else: + if self.item_id is not None: + raise ValueError("Item ID must be None if there is no item in the slot.") + if self.item_count is not None: + raise ValueError("Item count must be None if there is no item in the slot.") + if self.nbt is not None and not isinstance(self.nbt, EndNBT): + raise ValueError("NBT data must be None if there is no item in the slot.") diff --git a/mcproto/types/vec3.py b/mcproto/types/vec3.py new file mode 100644 index 00000000..1cc8fcbe --- /dev/null +++ b/mcproto/types/vec3.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import math + +from typing import cast, final +from typing_extensions import override + +from mcproto.buffer import Buffer +from mcproto.protocol import StructFormat +from mcproto.types.abc import MCType, dataclass + + +@dataclass +class Vec3(MCType): + """Represents a 3D vector. + + :param x: The x component. + :param y: The y component. + :param z: The z component. + """ + + x: float | int + y: float | int + z: float | int + + @override + def serialize_to(self, buf: Buffer) -> None: + buf.write_value(StructFormat.FLOAT, self.x) + buf.write_value(StructFormat.FLOAT, self.y) + buf.write_value(StructFormat.FLOAT, self.z) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Vec3: + x = buf.read_value(StructFormat.FLOAT) + y = buf.read_value(StructFormat.FLOAT) + z = buf.read_value(StructFormat.FLOAT) + return cls(x=x, y=y, z=z) + + @override + def validate(self) -> None: + """Validate the vector's components.""" + # Check that the components are floats or integers. + if not all(isinstance(comp, (float, int)) for comp in (self.x, self.y, self.z)): # type: ignore + raise TypeError(f"Vector components must be floats or integers, got {self.x!r}, {self.y!r}, {self.z!r}") + + # Check that the components are not NaN. + if any(not math.isfinite(comp) for comp in (self.x, self.y, self.z)): + raise ValueError(f"Vector components must not be NaN, got {self.x!r}, {self.y!r}, {self.z!r}.") + + def __add__(self, other: Vec3) -> Vec3: + # Use the type of self to return a Vec3 or a subclass. + return type(self)(x=self.x + other.x, y=self.y + other.y, z=self.z + other.z) + + def __sub__(self, other: Vec3) -> Vec3: + return type(self)(x=self.x - other.x, y=self.y - other.y, z=self.z - other.z) + + def __neg__(self) -> Vec3: + return type(self)(x=-self.x, y=-self.y, z=-self.z) + + def __mul__(self, other: float) -> Vec3: + return type(self)(x=self.x * other, y=self.y * other, z=self.z * other) + + def __truediv__(self, other: float) -> Vec3: + return type(self)(x=self.x / other, y=self.y / other, z=self.z / other) + + def to_position(self) -> Position: + """Convert the vector to a position.""" + return Position(x=int(self.x), y=int(self.y), z=int(self.z)) + + def to_tuple(self) -> tuple[float, float, float]: + """Convert the vector to a tuple.""" + return (self.x, self.y, self.z) + + def to_vec3(self) -> Vec3: + """Convert the vector to a Vec3. + + This function creates a new Vec3 object with the same components. + """ + return Vec3(x=self.x, y=self.y, z=self.z) + + def norm_squared(self) -> float: + """Return the squared norm of the vector.""" + return self.x**2 + self.y**2 + self.z**2 + + def norm(self) -> float: + """Return the norm of the vector.""" + return math.sqrt(self.norm_squared()) + + def normalize(self) -> Vec3: + """Return the normalized vector.""" + norm = self.norm() + return Vec3(x=self.x / norm, y=self.y / norm, z=self.z / norm) + + +@final +class Position(Vec3): + """Represents a position in the world. + + :param x: The x coordinate (26 bits). + :param y: The y coordinate (12 bits). + :param z: The z coordinate (26 bits). + """ + + __slots__ = () + + @override + def serialize_to(self, buf: Buffer) -> None: + self.x = cast(int, self.x) + self.y = cast(int, self.y) + self.z = cast(int, self.z) + encoded = ((self.x & 0x3FFFFFF) << 38) | ((self.z & 0x3FFFFFF) << 12) | (self.y & 0xFFF) + + # Convert the bit mess to a signed integer for packing. + if encoded & 0x8000000000000000: + encoded -= 1 << 64 + buf.write_value(StructFormat.LONGLONG, encoded) + + @override + @classmethod + def deserialize(cls, buf: Buffer) -> Position: + encoded = buf.read_value(StructFormat.LONGLONG) + x = (encoded >> 38) & 0x3FFFFFF + z = (encoded >> 12) & 0x3FFFFFF + y = encoded & 0xFFF + + # Convert back to signed integers. + if x >= 1 << 25: + x -= 1 << 26 + if y >= 1 << 11: + y -= 1 << 12 + if z >= 1 << 25: + z -= 1 << 26 + + return cls(x=x, y=y, z=z) + + @override + def validate(self) -> None: + """Validate the position's coordinates. + + They are all signed integers, but the x and z coordinates are 26 bits + and the y coordinate is 12 bits. + """ + super().validate() # Validate the Vec3 components. + + self.x = int(self.x) + self.y = int(self.y) + self.z = int(self.z) + if not (-1 << 25 <= self.x < 1 << 25): + raise OverflowError(f"Invalid x coordinate: {self.x}") + if not (-1 << 11 <= self.y < 1 << 11): + raise OverflowError(f"Invalid y coordinate: {self.y}") + if not (-1 << 25 <= self.z < 1 << 25): + raise OverflowError(f"Invalid z coordinate: {self.z}") + + +POS_UP = Position(0, 1, 0) +POS_DOWN = Position(0, -1, 0) +POS_NORTH = Position(0, 0, -1) +POS_SOUTH = Position(0, 0, 1) +POS_EAST = Position(1, 0, 0) +POS_WEST = Position(-1, 0, 0) + +POS_ZERO = Position(0, 0, 0) diff --git a/tests/mcproto/packets/login/test_login.py b/tests/mcproto/packets/login/test_login.py index 71067022..dfcb72bc 100644 --- a/tests/mcproto/packets/login/test_login.py +++ b/tests/mcproto/packets/login/test_login.py @@ -11,7 +11,7 @@ LoginSuccess, ) from mcproto.packets.packet import InvalidPacketContentError -from mcproto.types.chat import ChatMessage +from mcproto.types.chat import JSONTextComponent from mcproto.types.uuid import UUID from tests.helpers import gen_serializable_test from tests.mcproto.test_encryption import RSA_PUBLIC_KEY @@ -90,10 +90,10 @@ def test_login_encryption_request_noid(): gen_serializable_test( context=globals(), cls=LoginDisconnect, - fields=[("reason", ChatMessage)], + fields=[("reason", JSONTextComponent)], test_data=[ ( - (ChatMessage("You are banned."),), + (JSONTextComponent("You are banned."),), bytes.fromhex("1122596f75206172652062616e6e65642e22"), ), ], diff --git a/tests/mcproto/types/test_angle.py b/tests/mcproto/types/test_angle.py new file mode 100644 index 00000000..2de04fd9 --- /dev/null +++ b/tests/mcproto/types/test_angle.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import pytest + +from mcproto.types.vec3 import Position, POS_NORTH, POS_SOUTH, POS_EAST, POS_WEST, POS_ZERO +from mcproto.types.angle import Angle +from tests.helpers import gen_serializable_test + +PI = 3.14159265358979323846 +EPSILON = 1e-6 + +gen_serializable_test( + context=globals(), + cls=Angle, + fields=[("angle", int)], + test_data=[ + ((0,), b"\x00"), + ((256,), b"\x00"), + ((-1,), b"\xff"), + ((-256,), b"\x00"), + ((2,), b"\x02"), + ((-2,), b"\xfe"), + ], +) + + +@pytest.mark.parametrize( + ("angle", "base", "distance", "expected"), + [ + (Angle(0), POS_ZERO, 1, POS_SOUTH), + (Angle(64), POS_ZERO, 1, POS_WEST), + (Angle(128), POS_ZERO, 1, POS_NORTH), + (Angle(192), POS_ZERO, 1, POS_EAST), + ], +) +def test_in_direction(angle: Angle, base: Position, distance: int, expected: Position): + """Test that the in_direction method moves the base position in the correct direction.""" + assert (angle.in_direction(base, distance) - expected).norm() < EPSILON + + +@pytest.mark.parametrize( + ("base2", "degrees"), + [ + (0, 0), + (64, 90), + (128, 180), + (192, 270), + ], +) +def test_degrees(base2: int, degrees: int): + """Test that the from_degrees and to_degrees methods work correctly.""" + assert Angle.from_degrees(degrees) == Angle(base2) + assert Angle(base2).to_degrees() == degrees + + +@pytest.mark.parametrize( + ("rad", "angle"), + [ + (0, 0), + (PI / 2, 64), + (PI, 128), + (3 * PI / 2, 192), + ], +) +def test_radians(rad: float, angle: int): + """Test that the from_radians and to_radians methods work correctly.""" + assert Angle.from_radians(rad) == Angle(angle) + assert abs(Angle(angle).to_radians() - rad) < EPSILON diff --git a/tests/mcproto/types/test_bitset.py b/tests/mcproto/types/test_bitset.py new file mode 100644 index 00000000..0a349a93 --- /dev/null +++ b/tests/mcproto/types/test_bitset.py @@ -0,0 +1,204 @@ +from typing import List + +from mcproto.types.bitset import FixedBitset, Bitset +from tests.helpers import gen_serializable_test + +import pytest + + +gen_serializable_test( + context=globals(), + cls=FixedBitset.of_size(64), + fields=[("data", bytearray)], + test_data=[ + ((bytearray(b"\x00\x00\x00\x00\x00\x00\x00\x00"),), b"\x00\x00\x00\x00\x00\x00\x00\x00"), + ((bytearray(b"\xff\xff\xff\xff\xff\xff\xff\xff"),), b"\xff\xff\xff\xff\xff\xff\xff\xff"), + ((bytearray(b"\x55\x55\x55\x55\x55\x55\x55\x55"),), b"\x55\x55\x55\x55\x55\x55\x55\x55"), + ], +) + +gen_serializable_test( + context=globals(), + cls=FixedBitset.of_size(16), + fields=[("data", List[int])], + test_data=[ + ((bytearray(b"\x00"),), ValueError), + ], +) + +gen_serializable_test( + context=globals(), + cls=Bitset, + fields=[("size", int), ("data", List[int])], + test_data=[ + ((1, [1]), b"\x01\x00\x00\x00\x00\x00\x00\x00\x01"), + ( + (2, [1, -1]), + b"\x02\x00\x00\x00\x00\x00\x00\x00\x01\xff\xff\xff\xff\xff\xff\xff\xff", + ), + (IOError, b"\x01"), + ((3, [1]), ValueError), + ], +) + + +def test_fixed_bitset_indexing(): + """Test indexing and setting values in a FixedBitset.""" + b = FixedBitset.of_size(12).from_int(0) + assert b[0] is False + assert b[12] is False + + b[0] = True + assert b[0] is True + assert b[12] is False + + b[12] = True + assert b[12] is True + assert b[0] is True + + b[0] = False + assert b[0] is False + assert b[12] is True + + +def test_bitset_indexing(): + """Test indexing and setting values in a Bitset.""" + b = Bitset.from_int(0, size=2) + assert b[0] is False + assert b[127] is False + + b[0] = True + assert b[0] is True + + b[127] = True + assert b[127] is True + + b[0] = False + assert b[0] is False + + +def test_fixed_bitset_and(): + """Test bitwise AND operation between FixedBitsets.""" + b1 = FixedBitset.of_size(64).from_int(0xFFFFFFFFFFFFFFFF) + b2 = FixedBitset.of_size(64).from_int(0) + + result = b1 & b2 + assert bytes(result) == b"\x00\x00\x00\x00\x00\x00\x00\x00" + + +def test_bitset_and(): + """Test bitwise AND operation between Bitsets.""" + b1 = Bitset(2, [0x0101010101010101, 0x0101010101010100]) + b2 = Bitset(2, [1, 1]) + + result = b1 & b2 + assert result == Bitset(2, [1, 0]) + + +def test_fixed_bitset_or(): + """Test bitwise OR operation between FixedBitsets.""" + b1 = FixedBitset.of_size(8).from_int(0xFE) + b2 = FixedBitset.of_size(8).from_int(0x01) + + result = b1 | b2 + assert bytes(result) == b"\xff" + + +def test_bitset_or(): + """Test bitwise OR operation between Bitsets.""" + b1 = Bitset(2, [0x0101010101010101, 0x0101010101010100]) + b2 = Bitset(2, [1, 1]) + + result = b1 | b2 + assert bytes(result) == b"\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01" + + +def test_fixed_bitset_xor(): + """Test bitwise XOR operation between FixedBitsets.""" + b1 = FixedBitset.of_size(64)(bytearray(b"\xff\xff\xff\xff\xff\xff\xff\xff")) + b2 = FixedBitset.of_size(64)(bytearray(b"\x00\x00\x00\x00\x00\x00\x00\x00")) + + result = b1 ^ b2 + assert result == FixedBitset.of_size(64).from_int(-1) + + +def test_bitset_xor(): + """Test bitwise XOR operation between Bitsets.""" + b1 = Bitset(2, [0x0101010101010101, 0x0101010101010101]) + b2 = Bitset(2, [0, 0]) + + result = b1 ^ b2 + assert result == Bitset(2, [0x0101010101010101, 0x0101010101010101]) + + +def test_fixed_bitset_invert(): + """Test bitwise inversion operation on FixedBitsets.""" + b = FixedBitset.of_size(64)(bytearray(b"\xff\xff\xff\xff\xff\xff\xff\xff")) + + inverted = ~b + assert inverted == FixedBitset.of_size(64).from_int(0) + + +def test_bitset_invert(): + """Test bitwise inversion operation on Bitsets.""" + b = Bitset(2, [0, 0]) + + inverted = ~b + assert inverted == Bitset(2, [-1, -1]) + + +def test_fixed_bitset_size_undefined(): + """Test that FixedBitset raises ValueError when size is not defined.""" + with pytest.raises(ValueError): + FixedBitset.from_int(0) + + with pytest.raises(ValueError): + FixedBitset(bytearray(b"\x00\x00\x00\x00")) + + +def test_bitset_len(): + """Test that FixedBitset has the correct length.""" + b = FixedBitset.of_size(64).from_int(0) + assert len(b) == 64 + + b = FixedBitset.of_size(8).from_int(0) + assert len(b) == 8 + + b = Bitset(2, [0, 0]) + assert len(b) == 128 + + +def test_fixed_bitset_operations_length_mismatch(): + """Test that FixedBitset operations raise ValueError when lengths don't match.""" + b1 = FixedBitset.of_size(64).from_int(0) + b2 = FixedBitset.of_size(8).from_int(0) + b3 = "not a bitset" + + with pytest.raises(ValueError): + b1 & b2 # type: ignore + + with pytest.raises(ValueError): + b1 | b2 # type: ignore + + with pytest.raises(ValueError): + b1 ^ b2 # type: ignore + + assert b1 != b3 + + +def test_bitset_operations_length_mismatch(): + """Test that Bitset operations raise ValueError when lengths don't match.""" + b1 = Bitset(2, [0, 0]) + b2 = Bitset.from_int(1) + b3 = "not a bitset" + + with pytest.raises(ValueError): + b1 & b2 # type: ignore + + with pytest.raises(ValueError): + b1 | b2 # type: ignore + + with pytest.raises(ValueError): + b1 ^ b2 # type: ignore + + assert b1 != b3 diff --git a/tests/mcproto/types/test_chat.py b/tests/mcproto/types/test_chat.py index 3c8f05a6..4098dd05 100644 --- a/tests/mcproto/types/test_chat.py +++ b/tests/mcproto/types/test_chat.py @@ -2,8 +2,9 @@ import pytest -from mcproto.types.chat import ChatMessage, RawChatMessage, RawChatMessageDict +from mcproto.types.chat import JSONTextComponent, RawTextComponent, RawTextComponentDict, TextComponent from tests.helpers import gen_serializable_test +from mcproto.types.nbt import CompoundNBT, StringNBT, ByteNBT, ListNBT @pytest.mark.parametrize( @@ -23,9 +24,9 @@ ), ], ) -def test_as_dict(raw: RawChatMessage, expected_dict: RawChatMessageDict): - """Test converting raw ChatMessage input into dict produces expected dict.""" - chat = ChatMessage(raw) +def test_as_dict(raw: RawTextComponent, expected_dict: RawTextComponentDict): + """Test converting raw TextComponent input into dict produces expected dict.""" + chat = JSONTextComponent(raw) assert chat.as_dict() == expected_dict @@ -44,15 +45,15 @@ def test_as_dict(raw: RawChatMessage, expected_dict: RawChatMessageDict): ), ], ) -def test_equality(raw1: RawChatMessage, raw2: RawChatMessage, expected_result: bool): - """Test comparing ChatMessage instances produces expected equality result.""" - assert (ChatMessage(raw1) == ChatMessage(raw2)) is expected_result +def test_equality(raw1: RawTextComponent, raw2: RawTextComponent, expected_result: bool): + """Test comparing TextComponent instances produces expected equality result.""" + assert (JSONTextComponent(raw1) == JSONTextComponent(raw2)) is expected_result gen_serializable_test( context=globals(), - cls=ChatMessage, - fields=[("raw", RawChatMessage)], + cls=JSONTextComponent, + fields=[("raw", RawTextComponent)], test_data=[ ( ("A Minecraft Server",), @@ -74,3 +75,46 @@ def test_equality(raw1: RawChatMessage, raw2: RawChatMessage, expected_result: b (([[]],), TypeError), ], ) + +gen_serializable_test( + context=globals(), + cls=TextComponent, + fields=[("raw", RawTextComponent)], + test_data=[ + (({"text": "abc"},), bytes(CompoundNBT([StringNBT("abc", name="text")]).serialize())), + ( + ([{"text": "abc"}, {"text": "def"}],), + bytes( + CompoundNBT( + [ + StringNBT("abc", name="text"), + ListNBT([CompoundNBT([StringNBT("def", name="text")])], name="extra"), + ] + ).serialize() + ), + ), + (("A Minecraft Server",), bytes(CompoundNBT([StringNBT("A Minecraft Server", name="text")]).serialize())), + ( + ([{"text": "abc", "extra": [{"text": "def"}]}, {"text": "ghi"}],), + bytes( + CompoundNBT( + [ + StringNBT("abc", name="text"), + ListNBT( + [ + CompoundNBT([StringNBT("def", name="text")]), + CompoundNBT([StringNBT("ghi", name="text")]), + ], + name="extra", + ), + ] + ).serialize() + ), + ), + # Type shitfuckery + (TypeError, bytes(CompoundNBT([CompoundNBT([ByteNBT(0, "Something")], "text")]).serialize())), + (KeyError, bytes(CompoundNBT([ByteNBT(0, "unknownkey")]).serialize())), + (TypeError, bytes(CompoundNBT([ListNBT([StringNBT("Expected str")], "text")]).serialize())), + (TypeError, bytes(CompoundNBT([StringNBT("Wrong type", "extra")]).serialize())), + ], +) diff --git a/tests/mcproto/types/test_identifier.py b/tests/mcproto/types/test_identifier.py new file mode 100644 index 00000000..d7cbb731 --- /dev/null +++ b/tests/mcproto/types/test_identifier.py @@ -0,0 +1,19 @@ +from mcproto.types.identifier import Identifier +from tests.helpers import gen_serializable_test + + +gen_serializable_test( + context=globals(), + cls=Identifier, + fields=[("namespace", str), ("path", str)], + test_data=[ + (("minecraft", "stone"), b"\x0fminecraft:stone"), + (("minecraft", "stone_brick"), b"\x15minecraft:stone_brick"), + (("minecraft", "stone_brick_slab"), b"\x1aminecraft:stone_brick_slab"), + (("minecr*ft", "stone_brick_slab_top"), ValueError), # Invalid namespace + (("minecraft", "stone_brick_slab_t@p"), ValueError), # Invalid path + (("", "something"), ValueError), # Empty namespace + (("minecraft", ""), ValueError), # Empty path + (("minecraft", "a" * 32767), ValueError), # Too long + ], +) diff --git a/tests/mcproto/types/test_nbt.py b/tests/mcproto/types/test_nbt.py index c405a658..671f1707 100644 --- a/tests/mcproto/types/test_nbt.py +++ b/tests/mcproto/types/test_nbt.py @@ -1,7 +1,7 @@ from __future__ import annotations import struct -from typing import Any, cast +from typing import Any, Dict, List, cast import pytest @@ -19,978 +19,380 @@ LongArrayNBT, LongNBT, NBTag, - NBTagType, - PayloadType, ShortNBT, StringNBT, ) +from tests.helpers import gen_serializable_test # region EndNBT - -def test_serialize_deserialize_end(): - """Test serialization/deserialization of NBT END tag.""" - output_bytes = EndNBT().serialize() - assert output_bytes == bytearray.fromhex("00") - - buffer = Buffer() - EndNBT().serialize_to(buffer) - assert buffer == bytearray.fromhex("00") - - buffer.clear() - EndNBT().serialize_to(buffer, with_name=False) - assert buffer == bytearray.fromhex("00") - - buffer = Buffer(bytearray.fromhex("00")) - assert EndNBT.deserialize(buffer) == EndNBT() +gen_serializable_test( + context=globals(), + cls=EndNBT, + fields=[], + test_data=[ + ((), b"\x00"), + (IOError, b"\x01"), + ], +) # endregion # region Numerical NBT tests - -@pytest.mark.parametrize( - ("nbt_class", "value", "expected_bytes"), - [ - (ByteNBT, 0, bytearray.fromhex("01 00")), - (ByteNBT, 1, bytearray.fromhex("01 01")), - (ByteNBT, 127, bytearray.fromhex("01 7F")), - (ByteNBT, -128, bytearray.fromhex("01 80")), - (ByteNBT, -1, bytearray.fromhex("01 FF")), - (ByteNBT, 12, bytearray.fromhex("01 0C")), - (ShortNBT, 0, bytearray.fromhex("02 00 00")), - (ShortNBT, 1, bytearray.fromhex("02 00 01")), - (ShortNBT, 32767, bytearray.fromhex("02 7F FF")), - (ShortNBT, -32768, bytearray.fromhex("02 80 00")), - (ShortNBT, -1, bytearray.fromhex("02 FF FF")), - (ShortNBT, 12, bytearray.fromhex("02 00 0C")), - (IntNBT, 0, bytearray.fromhex("03 00 00 00 00")), - (IntNBT, 1, bytearray.fromhex("03 00 00 00 01")), - (IntNBT, 2147483647, bytearray.fromhex("03 7F FF FF FF")), - (IntNBT, -2147483648, bytearray.fromhex("03 80 00 00 00")), - (IntNBT, -1, bytearray.fromhex("03 FF FF FF FF")), - (IntNBT, 12, bytearray.fromhex("03 00 00 00 0C")), - (LongNBT, 0, bytearray.fromhex("04 00 00 00 00 00 00 00 00")), - (LongNBT, 1, bytearray.fromhex("04 00 00 00 00 00 00 00 01")), - (LongNBT, (1 << 63) - 1, bytearray.fromhex("04 7F FF FF FF FF FF FF FF")), - (LongNBT, -(1 << 63), bytearray.fromhex("04 80 00 00 00 00 00 00 00")), - (LongNBT, -1, bytearray.fromhex("04 FF FF FF FF FF FF FF FF")), - (LongNBT, 12, bytearray.fromhex("04 00 00 00 00 00 00 00 0C")), - (FloatNBT, 1.0, bytearray.fromhex("05") + bytes(struct.pack(">f", 1.0))), - (FloatNBT, 0.25, bytearray.fromhex("05") + bytes(struct.pack(">f", 0.25))), - (FloatNBT, -1.0, bytearray.fromhex("05") + bytes(struct.pack(">f", -1.0))), - (FloatNBT, 12.0, bytearray.fromhex("05") + bytes(struct.pack(">f", 12.0))), - (DoubleNBT, 1.0, bytearray.fromhex("06") + bytes(struct.pack(">d", 1.0))), - (DoubleNBT, 0.25, bytearray.fromhex("06") + bytes(struct.pack(">d", 0.25))), - (DoubleNBT, -1.0, bytearray.fromhex("06") + bytes(struct.pack(">d", -1.0))), - (DoubleNBT, 12.0, bytearray.fromhex("06") + bytes(struct.pack(">d", 12.0))), - (ByteArrayNBT, b"", bytearray.fromhex("07 00 00 00 00")), - (ByteArrayNBT, b"\x00", bytearray.fromhex("07 00 00 00 01") + b"\x00"), - (ByteArrayNBT, b"\x00\x01", bytearray.fromhex("07 00 00 00 02") + b"\x00\x01"), - ( - ByteArrayNBT, - b"\x00\x01\x02", - bytearray.fromhex("07 00 00 00 03") + b"\x00\x01\x02", - ), - ( - ByteArrayNBT, - b"\x00\x01\x02\x03", - bytearray.fromhex("07 00 00 00 04") + b"\x00\x01\x02\x03", - ), - ( - ByteArrayNBT, - b"\xff" * 1024, - bytearray.fromhex("07 00 00 04 00") + b"\xff" * 1024, - ), - ( - ByteArrayNBT, - bytes((n - 1) * n * 2 % 256 for n in range(256)), - bytearray.fromhex("07 00 00 01 00") + bytes((n - 1) * n * 2 % 256 for n in range(256)), - ), - (StringNBT, "", bytearray.fromhex("08 00 00")), - (StringNBT, "test", bytearray.fromhex("08 00 04") + b"test"), - (StringNBT, "a" * 100, bytearray.fromhex("08 00 64") + b"a" * (100)), - (StringNBT, "&à@é", bytearray.fromhex("08 00 06") + bytes("&à@é", "utf-8")), - (ListNBT, [], bytearray.fromhex("09 00 00 00 00 00")), - (ListNBT, [ByteNBT(0)], bytearray.fromhex("09 01 00 00 00 01 00")), - ( - ListNBT, - [ShortNBT(127), ShortNBT(256)], - bytearray.fromhex("09 02 00 00 00 02 00 7F 01 00"), - ), - ( - ListNBT, - [ListNBT([ByteNBT(0)]), ListNBT([IntNBT(256)])], - bytearray.fromhex("09 09 00 00 00 02 01 00 00 00 01 00 03 00 00 00 01 00 00 01 00"), - ), - (CompoundNBT, [], bytearray.fromhex("0A 00")), - ( - CompoundNBT, - [ByteNBT(0, name="test")], - bytearray.fromhex("0A") + ByteNBT(0, name="test").serialize() + b"\x00", - ), - ( - CompoundNBT, - [ShortNBT(128, "Short"), ByteNBT(-1, "Byte")], - bytearray.fromhex("0A") + ShortNBT(128, "Short").serialize() + ByteNBT(-1, "Byte").serialize() + b"\x00", - ), - ( - CompoundNBT, - [CompoundNBT([ByteNBT(0, name="Byte")], name="test")], - bytearray.fromhex("0A") + CompoundNBT([ByteNBT(0, name="Byte")], name="test").serialize() + b"\x00", - ), - ( - CompoundNBT, - [ - CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test"), - IntNBT(-1, "Int 2"), - ], - bytearray.fromhex("0A") - + CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test").serialize() - + IntNBT(-1, "Int 2").serialize() - + b"\x00", - ), - (IntArrayNBT, [], bytearray.fromhex("0B 00 00 00 00")), - (IntArrayNBT, [0], bytearray.fromhex("0B 00 00 00 01 00 00 00 00")), - ( - IntArrayNBT, - [0, 1], - bytearray.fromhex("0B 00 00 00 02 00 00 00 00 00 00 00 01"), - ), - ( - IntArrayNBT, - [1, 2, 3], - bytearray.fromhex("0B 00 00 00 03 00 00 00 01 00 00 00 02 00 00 00 03"), - ), - (IntArrayNBT, [(1 << 31) - 1], bytearray.fromhex("0B 00 00 00 01 7F FF FF FF")), - ( - IntArrayNBT, - [(1 << 31) - 1, (1 << 31) - 2], - bytearray.fromhex("0B 00 00 00 02 7F FF FF FF 7F FF FF FE"), - ), - ( - IntArrayNBT, - [-1, -2, -3], - bytearray.fromhex("0B 00 00 00 03 FF FF FF FF FF FF FF FE FF FF FF FD"), - ), - ( - IntArrayNBT, - [12] * 1024, - bytearray.fromhex("0B 00 00 04 00") + b"\x00\x00\x00\x0c" * 1024, - ), - (LongArrayNBT, [], bytearray.fromhex("0C 00 00 00 00")), - ( - LongArrayNBT, - [0], - bytearray.fromhex("0C 00 00 00 01 00 00 00 00 00 00 00 00"), - ), - ( - LongArrayNBT, - [0, 1], - bytearray.fromhex("0C 00 00 00 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 01"), - ), - ( - LongArrayNBT, - [1, 2, 3], - bytearray.fromhex( - "0C 00 00 00 03 00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 03" - ), - ), - ( - LongArrayNBT, - [(1 << 63) - 1], - bytearray.fromhex("0C 00 00 00 01 7F FF FF FF FF FF FF FF"), - ), - ( - LongArrayNBT, - [(1 << 63) - 1, (1 << 63) - 2], - bytearray.fromhex("0C 00 00 00 02 7F FF FF FF FF FF FF FF 7F FF FF FF FF FF FF FE"), - ), - ( - LongArrayNBT, - [-1, -2, -3], - bytearray.fromhex( - "0C 00 00 00 03 FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF FE FF FF FF FF FF FF FF FD" - ), - ), - ( - LongArrayNBT, - [12] * 1024, - bytearray.fromhex("0C 00 00 04 00") + b"\x00\x00\x00\x00\x00\x00\x00\x0c" * 1024, - ), +gen_serializable_test( + context=globals(), + cls=ByteNBT, + fields=[("payload", int), ("name", str)], + test_data=[ + ((0, "a"), b"\x01\x00\x01a\x00"), + ((1, "test"), b"\x01\x00\x04test\x01"), + ((127, "&à@é"), b"\x01\x00\x06" + bytes("&à@é", "utf-8") + b"\x7f"), + ((-128, "test"), b"\x01\x00\x04test\x80"), + ((-1, "a" * 100), b"\x01\x00\x64" + b"a" * 100 + b"\xff"), + # Errors + (IOError, b"\x01\x00\x04test"), + (IOError, b"\x01\x00\x04tes"), + (IOError, b"\x01\x00"), + (IOError, b"\x01"), + # Wrong type + (TypeError, b"\x02\x00\x01a\x00"), + (TypeError, b"\xff\x00\x01a\x00"), + # Out of bounds + ((1 << 7, "a"), OverflowError), + ((-(1 << 7) - 1, "a"), OverflowError), + ((1 << 8, "a"), OverflowError), + ((-(1 << 8) - 1, "a"), OverflowError), + ((1000, "a"), OverflowError), + ((1.5, "a"), TypeError), ], ) -def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType, expected_bytes: bytes): - """Test serialization/deserialization of NBT tag without name.""" - # Test serialization - output_bytes = nbt_class(value).serialize(with_name=False) - output_bytes_no_type = nbt_class(value).serialize(with_type=False, with_name=False) - assert output_bytes == expected_bytes - assert output_bytes_no_type == expected_bytes[1:] - - buffer = Buffer() - nbt_class(value).serialize_to(buffer, with_name=False) - assert buffer == expected_bytes - - # Test deserialization - buffer = Buffer(expected_bytes) - assert NBTag.deserialize(buffer, with_name=False) == nbt_class(value) - - buffer = Buffer(expected_bytes[1:]) - assert nbt_class.deserialize(buffer, with_type=False, with_name=False) == nbt_class(value) - - buffer = Buffer(expected_bytes) - assert nbt_class.read_from(buffer, with_name=False) == nbt_class(value) - - buffer = Buffer(expected_bytes[1:]) - assert nbt_class.read_from(buffer, with_type=False, with_name=False) == nbt_class(value) - -@pytest.mark.parametrize( - ("nbt_class", "value", "name", "expected_bytes"), - [ - ( - ByteNBT, - 0, - "test", - bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("00"), - ), - ( - ByteNBT, - 1, - "a", - bytearray.fromhex("01") + b"\x00\x01a" + bytearray.fromhex("01"), - ), - ( - ByteNBT, - 127, - "&à@é", - bytearray.fromhex("01 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F"), - ), - ( - ByteNBT, - -128, - "test", - bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("80"), - ), - ( - ByteNBT, - 12, - "a" * 100, - bytearray.fromhex("01") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("0C"), - ), - ( - ShortNBT, - 0, - "test", - bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("00 00"), - ), - ( - ShortNBT, - 1, - "a", - bytearray.fromhex("02") + b"\x00\x01a" + bytearray.fromhex("00 01"), - ), - ( - ShortNBT, - 32767, - "&à@é", - bytearray.fromhex("02 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF"), - ), - ( - ShortNBT, - -32768, - "test", - bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("80 00"), - ), - ( - ShortNBT, - 12, - "a" * 100, - bytearray.fromhex("02") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 0C"), - ), - ( - IntNBT, - 0, - "test", - bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), - ), - ( - IntNBT, - 1, - "a", - bytearray.fromhex("03") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01"), - ), - ( - IntNBT, - 2147483647, - "&à@é", - bytearray.fromhex("03 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF FF FF"), - ), - ( - IntNBT, - -2147483648, - "test", - bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00"), - ), - ( - IntNBT, - 12, - "a" * 100, - bytearray.fromhex("03") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 0C"), - ), - ( - LongNBT, - 0, - "test", - bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00 00 00 00"), - ), - ( - LongNBT, - 1, - "a", - bytearray.fromhex("04") + b"\x00\x01a" + bytearray.fromhex("00 00 00 00 00 00 00 01"), - ), - ( - LongNBT, - (1 << 63) - 1, - "&à@é", - bytearray.fromhex("04 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF FF FF FF FF FF FF"), - ), - ( - LongNBT, - -1 << 63, - "test", - bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00 00 00 00 00"), - ), - ( - LongNBT, - 12, - "a" * 100, - bytearray.fromhex("04") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 00 00 00 00 0C"), - ), - ( - FloatNBT, - 1.0, - "test", - bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 1.0)), - ), - ( - FloatNBT, - 0.25, - "a", - bytearray.fromhex("05") + b"\x00\x01a" + bytes(struct.pack(">f", 0.25)), - ), - ( - FloatNBT, - -1.0, - "&à@é", - bytearray.fromhex("05 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">f", -1.0)), - ), - ( - FloatNBT, - 12.0, - "test", - bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 12.0)), - ), - ( - DoubleNBT, - 1.0, - "test", - bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 1.0)), - ), - ( - DoubleNBT, - 0.25, - "a", - bytearray.fromhex("06") + b"\x00\x01a" + bytes(struct.pack(">d", 0.25)), - ), - ( - DoubleNBT, - -1.0, - "&à@é", - bytearray.fromhex("06 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">d", -1.0)), - ), - ( - DoubleNBT, - 12.0, - "test", - bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 12.0)), - ), - ( - ByteArrayNBT, - b"", - "test", - bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), - ), - ( - ByteArrayNBT, - b"\x00", - "a", - bytearray.fromhex("07") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01") + b"\x00", - ), - ( - ByteArrayNBT, - b"\x00\x01", - "&à@é", - bytearray.fromhex("07 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("00 00 00 02") + b"\x00\x01", - ), - ( - ByteArrayNBT, - b"\x00\x01\x02", - "test", - bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 03") + b"\x00\x01\x02", - ), - ( - ByteArrayNBT, - b"\xff" * 1024, - "a" * 100, - bytearray.fromhex("07") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 04 00") + b"\xff" * 1024, - ), - ( - StringNBT, - "", - "test", - bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 00"), - ), - ( - StringNBT, - "test", - "a", - bytearray.fromhex("08") + b"\x00\x01a" + bytearray.fromhex("00 04") + b"test", - ), - ( - StringNBT, - "a" * 100, - "&à@é", - bytearray.fromhex("08 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("00 64") + b"a" * 100, - ), - ( - StringNBT, - "&à@é", - "test", - bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 06") + bytes("&à@é", "utf-8"), - ), - ( - ListNBT, - [], - "test", - bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00"), - ), - ( - ListNBT, - [ByteNBT(-1)], - "a", - bytearray.fromhex("09") + b"\x00\x01a" + bytearray.fromhex("01 00 00 00 01 FF"), - ), - ( - ListNBT, - [ShortNBT(127), ShortNBT(256)], - "test", - bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("02 00 00 00 02 00 7F 01 00"), - ), - ( - ListNBT, - [ListNBT([ByteNBT(-1)]), ListNBT([IntNBT(256)])], - "a", - bytearray.fromhex("09") - + b"\x00\x01a" - + bytearray.fromhex("09 00 00 00 02 01 00 00 00 01 FF 03 00 00 00 01 00 00 01 00"), - ), - ( - CompoundNBT, - [], - "test", - bytearray.fromhex("0A") + b"\x00\x04test" + bytearray.fromhex("00"), - ), - ( - CompoundNBT, - [ByteNBT(0, name="Byte")], - "test", - bytearray.fromhex("0A") + b"\x00\x04test" + ByteNBT(0, name="Byte").serialize() + b"\x00", - ), - ( - CompoundNBT, - [ShortNBT(128, "Short"), ByteNBT(-1, "Byte")], - "test", - bytearray.fromhex("0A") - + b"\x00\x04test" - + ShortNBT(128, "Short").serialize() - + ByteNBT(-1, "Byte").serialize() - + b"\x00", - ), - ( - CompoundNBT, - [CompoundNBT([ByteNBT(0, name="Byte")], name="test")], - "test", - bytearray.fromhex("0A") - + b"\x00\x04test" - + CompoundNBT([ByteNBT(0, name="Byte")], "test").serialize() - + b"\x00", - ), - ( - CompoundNBT, - [ListNBT([ByteNBT(0)], name="List")], - "test", - bytearray.fromhex("0A") + b"\x00\x04test" + ListNBT([ByteNBT(0)], name="List").serialize() + b"\x00", - ), - ( - IntArrayNBT, - [], - "test", - bytearray.fromhex("0B") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), - ), - ( - IntArrayNBT, - [0], - "a", - bytearray.fromhex("0B") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01") + b"\x00\x00\x00\x00", - ), - ( - IntArrayNBT, - [0, 1], - "&à@é", - bytearray.fromhex("0B 00 06") - + bytes("&à@é", "utf-8") - + bytearray.fromhex("00 00 00 02") - + b"\x00\x00\x00\x00\x00\x00\x00\x01", - ), - ( - IntArrayNBT, - [1, 2, 3], - "test", - bytearray.fromhex("0B") - + b"\x00\x04test" - + bytearray.fromhex("00 00 00 03") - + b"\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03", - ), - ( - IntArrayNBT, - [(1 << 31) - 1], - "a" * 100, - bytearray.fromhex("0B") - + b"\x00\x64" - + b"a" * 100 - + bytearray.fromhex("00 00 00 01") - + b"\x7f\xff\xff\xff", - ), - ( - LongArrayNBT, - [], - "test", - bytearray.fromhex("0C") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), - ), - ( - LongArrayNBT, - [0], - "a", - bytearray.fromhex("0C") - + b"\x00\x01a" - + bytearray.fromhex("00 00 00 01") - + b"\x00\x00\x00\x00\x00\x00\x00\x00", - ), - ( - LongArrayNBT, - [0, 1], - "&à@é", - bytearray.fromhex("0C 00 06") - + bytes("&à@é", "utf-8") - + bytearray.fromhex("00 00 00 02") - + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - ), - ( - LongArrayNBT, - [1, 2, 3], - "test", - bytearray.fromhex("0C") - + b"\x00\x04test" - + bytearray.fromhex("00 00 00 03") - + bytearray.fromhex("00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 03"), - ), - ( - LongArrayNBT, - [(1 << 63) - 1] * 100, - "a" * 100, - bytearray.fromhex("0C") - + b"\x00\x64" - + b"a" * 100 - + bytearray.fromhex("00 00 00 64") - + b"\x7f\xff\xff\xff\xff\xff\xff\xff" * 100, - ), +gen_serializable_test( + context=globals(), + cls=ShortNBT, + fields=[("payload", int), ("name", str)], + test_data=[ + ((0, "a"), b"\x02\x00\x01a\x00\x00"), + ((1, "test"), b"\x02\x00\x04test\x00\x01"), + ((32767, "&à@é"), b"\x02\x00\x06" + bytes("&à@é", "utf-8") + b"\x7f\xff"), + ((-32768, "test"), b"\x02\x00\x04test\x80\x00"), + ((-1, "a" * 100), b"\x02\x00\x64" + b"a" * 100 + b"\xff\xff"), + # Errors + (IOError, b"\x02\x00\x04test"), + (IOError, b"\x02\x00\x04tes"), + (IOError, b"\x02\x00"), + (IOError, b"\x02"), + # Out of bounds + ((1 << 15, "a"), OverflowError), + ((-(1 << 15) - 1, "a"), OverflowError), + ((1 << 16, "a"), OverflowError), + ((-(1 << 16) - 1, "a"), OverflowError), + ((int(1e10), "a"), OverflowError), ], ) -def test_serialize_deserialize(nbt_class: type[NBTag], value: PayloadType, name: str, expected_bytes: bytes): - """Test serialization/deserialization of NBT tag with name.""" - # Test serialization - output_bytes = nbt_class(value, name).serialize() - output_bytes_no_type = nbt_class(value, name).serialize(with_type=False) - assert output_bytes == expected_bytes - assert output_bytes_no_type == expected_bytes[1:] - - buffer = Buffer() - nbt_class(value, name).serialize_to(buffer) - assert buffer == expected_bytes - - # Test deserialization - buffer = Buffer(expected_bytes * 2) - assert buffer.remaining == len(expected_bytes) * 2 - assert NBTag.deserialize(buffer) == nbt_class(value, name=name) - assert buffer.remaining == len(expected_bytes) - assert NBTag.deserialize(buffer) == nbt_class(value, name=name) - assert buffer.remaining == 0 - - buffer = Buffer(expected_bytes[1:]) - assert nbt_class.deserialize(buffer, with_type=False) == nbt_class(value, name=name) - buffer = Buffer(expected_bytes) - assert nbt_class.read_from(buffer) == nbt_class(value, name=name) - - buffer = Buffer(expected_bytes[1:]) - assert nbt_class.read_from(buffer, with_type=False) == nbt_class(value, name=name) - - -@pytest.mark.parametrize( - ("nbt_class", "size", "tag"), - [ - (ByteNBT, 8, NBTagType.BYTE), - (ShortNBT, 16, NBTagType.SHORT), - (IntNBT, 32, NBTagType.INT), - (LongNBT, 64, NBTagType.LONG), +gen_serializable_test( + context=globals(), + cls=IntNBT, + fields=[("payload", int), ("name", str)], + test_data=[ + ((0, "a"), b"\x03\x00\x01a\x00\x00\x00\x00"), + ((1, "test"), b"\x03\x00\x04test\x00\x00\x00\x01"), + ((2147483647, "&à@é"), b"\x03\x00\x06" + bytes("&à@é", "utf-8") + b"\x7f\xff\xff\xff"), + ((-2147483648, "test"), b"\x03\x00\x04test\x80\x00\x00\x00"), + ((-1, "a" * 100), b"\x03\x00\x64" + b"a" * 100 + b"\xff\xff\xff\xff"), + # Errors + (IOError, b"\x03\x00\x04test"), + (IOError, b"\x03\x00\x04tes"), + (IOError, b"\x03\x00"), + (IOError, b"\x03"), + # Out of bounds + ((1 << 31, "a"), OverflowError), + ((-(1 << 31) - 1, "a"), OverflowError), + ((1 << 32, "a"), OverflowError), + ((-(1 << 32) - 1, "a"), OverflowError), + ((int(1e30), "a"), OverflowError), ], ) -def test_serialize_deserialize_numerical_fail(nbt_class: type[NBTag], size: int, tag: NBTagType): - """Test serialization/deserialization of NBT NUM tag with invalid value.""" - # Out of bounds - with pytest.raises(OverflowError): - nbt_class(1 << (size - 1)).serialize(with_name=False) - - with pytest.raises(OverflowError): - nbt_class(-(1 << (size - 1)) - 1).serialize(with_name=False) - - # Deserialization - buffer = Buffer(bytearray([tag.value + 1] + [0] * (size // 8))) - with pytest.raises(TypeError): # Tries to read a nbt_class, but it's one higher - nbt_class.deserialize(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([tag.value] + [0] * ((size // 8) - 1))) - with pytest.raises(IOError): - nbt_class.read_from(buffer, with_name=False) - - buffer = Buffer(bytearray([tag.value, 0, 0] + [0] * (size // 8))) - assert nbt_class.read_from(buffer, with_name=True) == nbt_class(0) - - -# endregion - -# region FloatNBT - - -def test_serialize_deserialize_float_fail(): - """Test serialization/deserialization of NBT FLOAT tag with invalid value.""" - with pytest.raises(struct.error): - FloatNBT("test").serialize(with_name=False) - - with pytest.raises(OverflowError): - FloatNBT(1e39, "test").serialize() - - with pytest.raises(OverflowError): - FloatNBT(-1e39, "test").serialize() - - # Deserialization - buffer = Buffer(bytearray([NBTagType.BYTE] + [0] * 4)) - with pytest.raises(TypeError): # Tries to read a FloatNBT, but it's a ByteNBT - FloatNBT.deserialize(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.FLOAT, 0, 0, 0])) - with pytest.raises(IOError): - FloatNBT.read_from(buffer, with_name=False) - - -# endregion -# region DoubleNBT - - -def test_serialize_deserialize_double_fail(): - """Test serialization/deserialization of NBT DOUBLE tag with invalid value.""" - with pytest.raises(struct.error): - DoubleNBT("test").serialize(with_name=False) - - # Deserialization - buffer = Buffer(bytearray([0x01] + [0] * 8)) - with pytest.raises(TypeError): # Tries to read a DoubleNBT, but it's a ByteNBT - DoubleNBT.deserialize(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.DOUBLE, 0, 0, 0, 0, 0, 0, 0])) - with pytest.raises(IOError): - DoubleNBT.read_from(buffer, with_name=False) - - -# endregion -# region ByteArrayNBT - - -def test_serialize_deserialize_bytearray_fail(): - """Test serialization/deserialization of NBT BYTEARRAY tag with invalid value.""" - # Deserialization - buffer = Buffer(bytearray([0x01] + [0] * 4)) - with pytest.raises(TypeError): # Tries to read a ByteArrayNBT, but it's a ByteNBT - ByteArrayNBT.deserialize(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0])) # Missing length bytes - with pytest.raises(IOError): - ByteArrayNBT.read_from(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0, 1])) # Missing data bytes - with pytest.raises(IOError): - ByteArrayNBT.read_from(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0, 2, 0])) # Missing data bytes - with pytest.raises(IOError): - ByteArrayNBT.read_from(buffer, with_name=False) - - # Negative length - buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0xFF, 0xFF, 0xFF, 0xFF])) # length = -1 - with pytest.raises(ValueError): - ByteArrayNBT.deserialize(buffer, with_name=False) - - -# endregion -# region StringNBT - - -def test_serialize_deserialize_string_fail(): - """Test serialization/deserialization of NBT STRING tag with invalid value.""" - # Deserialization - buffer = Buffer(bytearray([0x01, 0, 0])) - with pytest.raises(TypeError): # Tries to read a StringNBT, but it's a ByteNBT - StringNBT.deserialize(buffer, with_name=False) - - # Not enough data for the length - buffer = Buffer(bytearray([NBTagType.STRING, 0])) - with pytest.raises(IOError): - StringNBT.read_from(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.STRING, 0, 1])) - with pytest.raises(IOError): - StringNBT.read_from(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.STRING, 0, 2, 0])) - with pytest.raises(IOError): - StringNBT.read_from(buffer, with_name=False) - - # Negative length - buffer = Buffer(bytearray([NBTagType.STRING, 0xFF, 0xFF])) # length = -1 - with pytest.raises(ValueError): - StringNBT.deserialize(buffer, with_name=False) - - # Invalid UTF-8 - buffer = Buffer(bytearray([NBTagType.STRING, 0, 1, 0xC0, 0x80])) - with pytest.raises(UnicodeDecodeError): - StringNBT.read_from(buffer, with_name=False) +gen_serializable_test( + context=globals(), + cls=LongNBT, + fields=[("payload", int), ("name", str)], + test_data=[ + ((0, "a"), b"\x04\x00\x01a\x00\x00\x00\x00\x00\x00\x00\x00"), + ((1, "test"), b"\x04\x00\x04test\x00\x00\x00\x00\x00\x00\x00\x01"), + (((1 << 63) - 1, "&à@é"), b"\x04\x00\x06" + bytes("&à@é", "utf-8") + b"\x7f\xff\xff\xff\xff\xff\xff\xff"), + ((-1 << 63, "test"), b"\x04\x00\x04test\x80\x00\x00\x00\x00\x00\x00\x00"), + ((-1, "a" * 100), b"\x04\x00\x64" + b"a" * 100 + b"\xff\xff\xff\xff\xff\xff\xff\xff"), + # Errors + (IOError, b"\x04\x00\x04test"), + (IOError, b"\x04\x00\x04tes"), + (IOError, b"\x04\x00"), + (IOError, b"\x04"), + # Out of bounds + ((1 << 63, "a"), OverflowError), + ((-(1 << 63) - 1, "a"), OverflowError), + ((1 << 64, "a"), OverflowError), + ((-(1 << 64) - 1, "a"), OverflowError), + ], +) # endregion -# region ListNBT - - -@pytest.mark.parametrize( - ("payload", "error"), - [ - ([ByteNBT(0), IntNBT(0)], ValueError), - ([ByteNBT(0), "test"], ValueError), - ([ByteNBT(0), None], ValueError), - ([ByteNBT(0), ByteNBT(-1, "Hello World")], ValueError), # All unnamed tags - ([ByteNBT(128), ByteNBT(-1)], OverflowError), # Check for error propagation +# region Floating point NBT tests +gen_serializable_test( + context=globals(), + cls=FloatNBT, + fields=[("payload", float), ("name", str)], + test_data=[ + ((1.0, "a"), b"\x05\x00\x01a" + bytes(struct.pack(">f", 1.0))), + ((0.5, "test"), b"\x05\x00\x04test" + bytes(struct.pack(">f", 0.5))), # has to be convertible to float exactly + ((-1.0, "&à@é"), b"\x05\x00\x06" + bytes("&à@é", "utf-8") + bytes(struct.pack(">f", -1.0))), + ((12.0, "a" * 100), b"\x05\x00\x64" + b"a" * 100 + bytes(struct.pack(">f", 12.0))), + ((1, "a"), b"\x05\x00\x01a" + bytes(struct.pack(">f", 1.0))), + # Errors + (IOError, b"\x05\x00\x04test"), + (IOError, b"\x05\x00\x04tes"), + (IOError, b"\x05\x00"), + (IOError, b"\x05"), + # Wrong type + (("1.5", "a"), TypeError), ], ) -def test_serialize_list_fail(payload: PayloadType, error: type[Exception]): - """Test serialization of NBT LIST tag with invalid value.""" - with pytest.raises(error): - ListNBT(payload, "test").serialize() - - -def test_deserialize_list_fail(): - """Test deserialization of NBT LIST tag with invalid value.""" - # Wrong tag type - buffer = Buffer(bytearray([0x09, 255, 0, 0, 0, 1, 0])) - with pytest.raises(TypeError): - ListNBT.deserialize(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([0x09, 1, 0, 0, 0, 1])) - with pytest.raises(IOError): - ListNBT.read_from(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([0x09, 1, 0, 0, 0])) - with pytest.raises(IOError): - ListNBT.read_from(buffer, with_name=False) - +gen_serializable_test( + context=globals(), + cls=DoubleNBT, + fields=[("payload", float), ("name", str)], + test_data=[ + ((1.0, "a"), b"\x06\x00\x01a" + bytes(struct.pack(">d", 1.0))), + ((3.14, "test"), b"\x06\x00\x04test" + bytes(struct.pack(">d", 3.14))), + ((-1.0, "&à@é"), b"\x06\x00\x06" + bytes("&à@é", "utf-8") + bytes(struct.pack(">d", -1.0))), + ((12.0, "a" * 100), b"\x06\x00\x64" + b"a" * 100 + bytes(struct.pack(">d", 12.0))), + # Errors + (IOError, b"\x06\x00\x04test\x01"), + (IOError, b"\x06\x00\x04test"), + (IOError, b"\x06\x00\x04tes"), + (IOError, b"\x06\x00"), + (IOError, b"\x06"), + ], +) # endregion -# region CompoundNBT - - -@pytest.mark.parametrize( - ("payload", "error"), - [ - ([ByteNBT(0, name="Hello"), IntNBT(0)], ValueError), - ([ByteNBT(0, name="hi"), "test"], ValueError), - ([ByteNBT(0, name="hi"), None], ValueError), - ([ByteNBT(0), ByteNBT(-1, "Hello World")], ValueError), # All unnamed tags - ( - [ByteNBT(128, name="Jello"), ByteNBT(-1, name="Bonjour")], - OverflowError, - ), # Check for error propagation +# region Variable Length NBT tests +gen_serializable_test( + context=globals(), + cls=ByteArrayNBT, + fields=[("payload", bytes), ("name", str)], + test_data=[ + ((b"", "a"), b"\x07\x00\x01a\x00\x00\x00\x00"), + ((b"\x00", "test"), b"\x07\x00\x04test\x00\x00\x00\x01\x00"), + ((b"\x00\x01", "&à@é"), b"\x07\x00\x06" + bytes("&à@é", "utf-8") + b"\x00\x00\x00\x02\x00\x01"), + ((b"\x00\x01\x02", "test"), b"\x07\x00\x04test\x00\x00\x00\x03\x00\x01\x02"), + ((b"\xff" * 1024, "a" * 100), b"\x07\x00\x64" + b"a" * 100 + b"\x00\x00\x04\x00" + b"\xff" * 1024), + ((b"Hello World", "test"), b"\x07\x00\x04test\x00\x00\x00\x0b" + b"Hello World"), + ((bytearray(b"Hello World"), "test"), b"\x07\x00\x04test\x00\x00\x00\x0b" + b"Hello World"), + # Errors + (IOError, b"\x07\x00\x04test"), + (IOError, b"\x07\x00\x04tes"), + (IOError, b"\x07\x00"), + (IOError, b"\x07"), + (IOError, b"\x07\x00\x01a\x00\x01"), + (IOError, b"\x07\x00\x01a\x00\x00\x00\xff"), + # Negative length + (ValueError, b"\x07\x00\x01a\xff\xff\xff\xff"), + # Wrong type + ((1, "a"), TypeError), ], ) -def test_serialize_compound_fail(payload: PayloadType, error: type[Exception]): - """Test serialization of NBT COMPOUND tag with invalid value.""" - with pytest.raises(error): - CompoundNBT(payload, "test").serialize() - - # Double name - with pytest.raises(ValueError): - CompoundNBT([ByteNBT(0, name="test"), ByteNBT(0, name="test")], "comp").serialize() - -def test_deseialize_compound_fail(): - """Test deserialization of NBT COMPOUND tag with invalid value.""" - # Not enough data - buffer = Buffer(bytearray([NBTagType.COMPOUND, 0x01])) - with pytest.raises(IOError): - CompoundNBT.read_from(buffer, with_name=False) - - # Not enough data - buffer = Buffer(bytearray([NBTagType.COMPOUND])) - with pytest.raises(IOError): - CompoundNBT.read_from(buffer, with_name=False) +gen_serializable_test( + context=globals(), + cls=StringNBT, + fields=[("payload", str), ("name", str)], + test_data=[ + (("", "a"), b"\x08\x00\x01a\x00\x00"), + (("test", "a"), b"\x08\x00\x01a\x00\x04" + b"test"), + (("a" * 100, "&à@é"), b"\x08\x00\x06" + bytes("&à@é", "utf-8") + b"\x00\x64" + b"a" * 100), + (("&à@é", "test"), b"\x08\x00\x04test\x00\x06" + bytes("&à@é", "utf-8")), + # Errors + (IOError, b"\x08\x00\x04test"), + (IOError, b"\x08\x00\x04tes"), + (IOError, b"\x08\x00"), + (IOError, b"\x08"), + # Negative length + (ValueError, b"\x08\xff\xff\xff\xff"), + # Unicode decode error + (UnicodeDecodeError, b"\x08\x00\x01a\x00\x01\xff"), + # String too long + (("a" * 32768, "b"), ValueError), + # Wrong type + ((1, "a"), TypeError), + ], +) - # Wrong tag type - buffer = Buffer(bytearray([15])) - with pytest.raises(TypeError): - NBTag.deserialize(buffer) +gen_serializable_test( + context=globals(), + cls=ListNBT, + fields=[("payload", list), ("name", str)], + test_data=[ + # Here we only want to test ListNBT related stuff + (([], "a"), b"\x09\x00\x01a\x00\x00\x00\x00\x00"), + (([ByteNBT(-1)], "a"), b"\x09\x00\x01a\x01\x00\x00\x00\x01\xff"), + (([ListNBT([])], "a"), b"\x09\x00\x01a\x09\x00\x00\x00\x01" + ListNBT([]).serialize()[1:]), + (([ListNBT([ByteNBT(6)])], "a"), b"\x09\x00\x01a\x09\x00\x00\x00\x01" + ListNBT([ByteNBT(6)]).serialize()[1:]), + ( + ([ListNBT([ByteNBT(-1)]), ListNBT([IntNBT(1234)])], "a"), + b"\x09\x00\x01a\x09\x00\x00\x00\x02" + + ListNBT([ByteNBT(-1)]).serialize()[1:] + + ListNBT([IntNBT(1234)]).serialize()[1:], + ), + ( + ([ListNBT([ByteNBT(-1)]), ListNBT([IntNBT(128), IntNBT(8)])], "a"), + b"\x09\x00\x01a\x09\x00\x00\x00\x02" + + ListNBT([ByteNBT(-1)]).serialize()[1:] + + ListNBT([IntNBT(128), IntNBT(8)]).serialize()[1:], + ), + # Errors + # Not enough data + (IOError, b"\x09\x00\x01a"), + (IOError, b"\x09\x00\x01a\x01"), + (IOError, b"\x09\x00\x01a\x01\x00"), + (IOError, b"\x09\x00\x01a\x01\x00\x00\x00\x01"), + (IOError, b"\x09\x00\x01a\x01\x00\x00\x00\x03\x01"), + # Invalid tag type + (TypeError, b"\x09\x00\x01a\xff\x00\x00\x01\x00"), + # Not NBTags + (([1, 2, 3], "a"), TypeError), + # Not the same tag type + (([ByteNBT(0), IntNBT(0)], "a"), TypeError), + # Contains named tags + (([ByteNBT(0, name="Byte")], "a"), ValueError), + # Wrong type + ((1, "a"), TypeError), + ], +) +gen_serializable_test( + context=globals(), + cls=CompoundNBT, + fields=[("payload", list), ("name", str)], + test_data=[ + (([], "a"), b"\x0a\x00\x01a\x00"), + (([ByteNBT(0, name="Byte")], "a"), b"\x0a\x00\x01a" + ByteNBT(0, name="Byte").serialize() + b"\x00"), + ( + ([ShortNBT(128, "Short"), ByteNBT(-1, "Byte")], "a"), + b"\x0a\x00\x01a" + ShortNBT(128, "Short").serialize() + ByteNBT(-1, "Byte").serialize() + b"\x00", + ), + ( + ([CompoundNBT([ByteNBT(0, name="Byte")], name="test")], "a"), + b"\x0a\x00\x01a" + CompoundNBT([ByteNBT(0, name="Byte")], "test").serialize() + b"\x00", + ), + ( + ([ListNBT([ByteNBT(0)] * 3, name="List")], "a"), + b"\x0a\x00\x01a" + ListNBT([ByteNBT(0)] * 3, name="List").serialize() + b"\x00", + ), + # Errors + # Not enough data + (IOError, b"\x0a\x00\x01a"), + (IOError, b"\x0a\x00\x01a\x01"), + # All muse be NBTags + (([0, 1, 2], "a"), TypeError), + # All with a name + (([ByteNBT(0)], "a"), ValueError), + # Must be unique + (([ByteNBT(0, name="Byte"), ByteNBT(0, name="Byte")], "a"), ValueError), + # Wrong type + ((1, "a"), TypeError), + ], +) -def test_nbtag_deserialize_compound(): - """Test deserialization of NBT COMPOUND tag from the NBTag class.""" - buf = Buffer(bytearray([0x00])) - assert NBTag.deserialize(buf, with_type=False, with_name=False) == CompoundNBT([]) +gen_serializable_test( + context=globals(), + cls=IntArrayNBT, + fields=[("payload", list), ("name", str)], + test_data=[ + (([], "a"), b"\x0b\x00\x01a\x00\x00\x00\x00"), + (([0], "a"), b"\x0b\x00\x01a\x00\x00\x00\x01\x00\x00\x00\x00"), + (([0, 1], "a"), b"\x0b\x00\x01a\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01"), + (([1, 2, 3], "a"), b"\x0b\x00\x01a\x00\x00\x00\x03\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03"), + (([(1 << 31) - 1], "a"), b"\x0b\x00\x01a\x00\x00\x00\x01\x7f\xff\xff\xff"), + (([-1, -2, -3], "a"), b"\x0b\x00\x01a\x00\x00\x00\x03\xff\xff\xff\xff\xff\xff\xff\xfe\xff\xff\xff\xfd"), + # Errors + # Not enough data + (IOError, b"\x0b\x00\x01a"), + (IOError, b"\x0b\x00\x01a\x01"), + (IOError, b"\x0b\x00\x01a\x00\x00\x00\x01"), + (IOError, b"\x0b\x00\x01a\x00\x00\x00\x03\x01"), + # Must contain ints only + ((["a"], "a"), TypeError), + (([IntNBT(0)], "a"), TypeError), + (([1 << 31], "a"), OverflowError), + (([-(1 << 31) - 1], "a"), OverflowError), + # Wrong type + ((1, "a"), TypeError), + ], +) +gen_serializable_test( + context=globals(), + cls=LongArrayNBT, + fields=[("payload", list), ("name", str)], + test_data=[ + (([], "a"), b"\x0c\x00\x01a\x00\x00\x00\x00"), + (([0], "a"), b"\x0c\x00\x01a\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00"), + ( + ([0, 1], "a"), + b"\x0c\x00\x01a\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + ), + (([(1 << 63) - 1], "a"), b"\x0c\x00\x01a\x00\x00\x00\x01\x7f\xff\xff\xff\xff\xff\xff\xff"), + ( + ([-1, -2], "a"), + b"\x0c\x00\x01a\x00\x00\x00\x02\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xfe", + ), + # Not enough data + (IOError, b"\x0c\x00\x01a"), + (IOError, b"\x0c\x00\x01a\x01"), + (IOError, b"\x0c\x00\x01a\x00\x00\x00\x01"), + (IOError, b"\x0c\x00\x01a\x00\x00\x00\x03\x01"), + # Must contain ints only + ((["a"], "a"), TypeError), + (([LongNBT(0)], "a"), TypeError), + (([1 << 63], "a"), OverflowError), + (([-(1 << 63) - 1], "a"), OverflowError), + ], +) - buf = Buffer(bytearray.fromhex("0A 00 01 61 01 00 01 62 00 00")) - assert NBTag.deserialize(buf) == CompoundNBT([ByteNBT(0, name="b")], name="a") +# endregion +# region CompoundNBT def test_equality_compound(): """Test equality of CompoundNBT.""" - comp1 = CompoundNBT( - [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], - "comp", - ) - comp2 = CompoundNBT( - [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], - "comp", - ) + comp1 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp") + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp") assert comp1 == comp2 comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2")], "comp") assert comp1 != comp2 - comp2 = CompoundNBT( - [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test4")], - "comp", - ) + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test4")], "comp") assert comp1 != comp2 - comp2 = CompoundNBT( - [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], - "comp2", - ) + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp2") assert comp1 != comp2 assert comp1 != ByteNBT(0, name="comp") # endregion -# region IntArrayNBT - - -@pytest.mark.parametrize( - ("payload", "error"), - [ - ([0, "test"], ValueError), - ([0, None], ValueError), - ([0, 1 << 31], OverflowError), - ([0, -(1 << 31) - 1], OverflowError), - ], -) -def test_serialize_intarray_fail(payload: PayloadType, error: type[Exception]): - """Test serialization of NBT INTARRAY tag with invalid value.""" - with pytest.raises(error): - IntArrayNBT(payload, "test").serialize() - - -def test_deserialize_intarray_fail(): - """Test deserialization of NBT INTARRAY tag with invalid value.""" - # Not enough data for 1 element - buffer = Buffer(bytearray([0x0B, 0, 0, 0, 1, 0, 0, 0])) - with pytest.raises(IOError): - IntArrayNBT.deserialize(buffer, with_name=False) - - # Not enough data for the size - buffer = Buffer(bytearray([0x0B, 0, 0, 0])) - with pytest.raises(IOError): - IntArrayNBT.read_from(buffer, with_name=False) - - # Not enough data to start the 2nd element - buffer = Buffer(bytearray([0x0B, 0, 0, 0, 2, 1, 0, 0, 0])) - with pytest.raises(IOError): - IntArrayNBT.read_from(buffer, with_name=False) - - -# endregion -# region LongArrayNBT - - -@pytest.mark.parametrize( - ("payload", "error"), - [ - ([0, "test"], ValueError), - ([0, None], ValueError), - ([0, 1 << 63], OverflowError), - ([0, -(1 << 63) - 1], OverflowError), - ], -) -def test_serialize_deserialize_longarray_fail(payload: PayloadType, error: type[Exception]): - """Test serialization/deserialization of NBT LONGARRAY tag with invalid value.""" - with pytest.raises(error): - LongArrayNBT(payload, "test").serialize() - -def test_deserialize_longarray_fail(): - """Test deserialization of NBT LONGARRAY tag with invalid value.""" - # Not enough data for 1 element - buffer = Buffer(bytearray([0x0C, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0])) - with pytest.raises(IOError): - LongArrayNBT.deserialize(buffer, with_name=False) +# region ListNBT - # Not enough data for the size - buffer = Buffer(bytearray([0x0C, 0, 0, 0])) - with pytest.raises(IOError): - LongArrayNBT.read_from(buffer, with_name=False) - # Not enough data to start the 2nd element - buffer = Buffer(bytearray([0x0C, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0])) - with pytest.raises(IOError): - LongArrayNBT.read_from(buffer, with_name=False) +def test_intarray_negative_length(): + """Test IntArray with negative length.""" + buffer = Buffer(b"\x0b\x00\x01a\xff\xff\xff\xff") + assert IntArrayNBT.read_from(buffer) == IntArrayNBT([], "a") # endregion @@ -1019,7 +421,7 @@ def test_nbt_helloworld(): def test_nbt_bigfile(): """Test serialization/deserialization of a big NBT tag. - Slightly modified from the source data to also include a IntArrayNBT and a LongArrayNBT. + Slighly modified from the source data to also include a IntArrayNBT and a LongArrayNBT. Source data: https://wiki.vg/NBT#Example. """ data = "0a00054c6576656c0400086c6f6e67546573747fffffffffffffff02000973686f7274546573747fff08000a737472696e6754657374002948454c4c4f20574f524c4420544849532049532041205445535420535452494e4720c385c384c39621050009666c6f6174546573743eff1832030007696e74546573747fffffff0a00146e657374656420636f6d706f756e6420746573740a000368616d0800046e616d65000648616d70757305000576616c75653f400000000a00036567670800046e616d6500074567676265727405000576616c75653f00000000000c000f6c6973745465737420286c6f6e672900000005000000000000000b000000000000000c000000000000000d000000000000000e7fffffffffffffff0b000e6c697374546573742028696e7429000000047fffffff7ffffffe7ffffffd7ffffffc0900136c697374546573742028636f6d706f756e64290a000000020800046e616d65000f436f6d706f756e642074616720233004000a637265617465642d6f6e000001265237d58d000800046e616d65000f436f6d706f756e642074616720233104000a637265617465642d6f6e000001265237d58d0001000862797465546573747f07006562797465417272617954657374202874686520666972737420313030302076616c756573206f6620286e2a6e2a3235352b6e2a3729253130302c207374617274696e672077697468206e3d302028302c2036322c2033342c2031362c20382c202e2e2e2929000003e8003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a063005000a646f75626c65546573743efc000000" # noqa: E501 @@ -1084,8 +486,8 @@ def check_equality(self: object, other: object) -> bool: if type(self) != type(other): return False if isinstance(self, dict): - self = cast("dict[Any, Any]", self) - other = cast("dict[Any, Any]", other) + self = cast(Dict[Any, Any], self) + other = cast(Dict[Any, Any], other) if len(self) != len(other): return False for key in self: @@ -1095,8 +497,8 @@ def check_equality(self: object, other: object) -> bool: return False return True if isinstance(self, list): - self = cast("list[Any]", self) - other = cast("list[Any]", other) + self = cast(List[Any], self) + other = cast(List[Any], other) if len(self) != len(other): return False return all(check_equality(self[i], other[i]) for i in range(len(self))) @@ -1112,8 +514,6 @@ def check_equality(self: object, other: object) -> bool: # endregion # region Edge cases - - def test_from_object_morecases(): """Test from_object with more edge cases.""" @@ -1199,7 +599,7 @@ def to_nbt(self, name: str = "") -> NBTag: ), ([1], [], ValueError, "The schema and the data must have the same length."), # Schema is a dict, data is not - (["test"], {"test": ByteNBT}, TypeError, "Expected a dictionary, but found list."), + (["test"], {"test": ByteNBT}, TypeError, "Expected a dictionary, but found a different type."), # Schema is not a dict, list or subclass of NBTagConvertible ( ["test"], @@ -1246,7 +646,7 @@ def to_nbt(self, name: str = "") -> NBTag: {"test": object()}, ByteNBT, TypeError, - r"Expected one of \(bytes, str, int, float, list\), but found object.", + "Expected one of \\(bytes, str, int, float, list\\), but found object.", ), # The data is a list but not all elements are ints ( @@ -1390,6 +790,3 @@ def test_wrong_type(buffer_content: str, tag_type: type[NBTag]): buffer = Buffer(bytearray.fromhex(buffer_content)) with pytest.raises(TypeError): tag_type.read_from(buffer, with_name=False) - - -# endregion diff --git a/tests/mcproto/types/test_quaternion.py b/tests/mcproto/types/test_quaternion.py new file mode 100644 index 00000000..096e685b --- /dev/null +++ b/tests/mcproto/types/test_quaternion.py @@ -0,0 +1,121 @@ +from __future__ import annotations +import struct +from typing import cast +import pytest +import math +from mcproto.types.quaternion import Quaternion +from tests.helpers import gen_serializable_test + +gen_serializable_test( + context=globals(), + cls=Quaternion, + fields=[("x", float), ("y", float), ("z", float), ("w", float)], + test_data=[ + ((0.0, 0.0, 0.0, 0.0), struct.pack(">ffff", 0.0, 0.0, 0.0, 0.0)), + ((-1.0, -1.0, -1.0, -1.0), struct.pack(">ffff", -1.0, -1.0, -1.0, -1.0)), + ((1.0, 2.0, 3.0, 4.0), struct.pack(">ffff", 1.0, 2.0, 3.0, 4.0)), + ((1.5, 2.5, 3.5, 4.5), struct.pack(">ffff", 1.5, 2.5, 3.5, 4.5)), + # Invalid values + ((1.0, 2.0, "3.0", 4.0), TypeError), + ((float("nan"), 2.0, 3.0, 4.0), ValueError), + ((1.0, float("inf"), 3.0, 4.0), ValueError), + ((1.0, 2.0, -float("inf"), 4.0), ValueError), + ], +) + + +def test_quaternion_addition(): + """Test that two Quaternion objects can be added together (resulting in a new Quaternion object).""" + v1 = Quaternion(x=1.0, y=2.0, z=3.0, w=4.0) + v2 = Quaternion(x=4.5, y=5.25, z=6.125, w=7.0625) + v3 = v1 + v2 + assert type(v3) == Quaternion + assert v3.x == 5.5 + assert v3.y == 7.25 + assert v3.z == 9.125 + assert v3.w == 11.0625 + + +def test_quaternion_subtraction(): + """Test that two Quaternion objects can be subtracted (resulting in a new Quaternion object).""" + v1 = Quaternion(x=1.0, y=2.0, z=3.0, w=4.0) + v2 = Quaternion(x=4.5, y=5.25, z=6.125, w=7.0625) + v3 = v2 - v1 + assert type(v3) == Quaternion + assert v3.x == 3.5 + assert v3.y == 3.25 + assert v3.z == 3.125 + assert v3.w == 3.0625 + + +def test_quaternion_negative(): + """Test that a Quaternion object can be negated.""" + v1 = Quaternion(x=1.0, y=2.5, z=3.0, w=4.5) + v2 = -v1 + assert type(v2) == Quaternion + assert v2.x == -1.0 + assert v2.y == -2.5 + assert v2.z == -3.0 + assert v2.w == -4.5 + + +def test_quaternion_multiplication_int(): + """Test that a Quaternion object can be multiplied by an integer.""" + v1 = Quaternion(x=1.0, y=2.25, z=3.0, w=4.5) + v2 = v1 * 2 + assert v2.x == 2.0 + assert v2.y == 4.5 + assert v2.z == 6.0 + assert v2.w == 9.0 + + +def test_quaternion_multiplication_float(): + """Test that a Quaternion object can be multiplied by a float.""" + v1 = Quaternion(x=2.0, y=4.5, z=6.0, w=9.0) + v2 = v1 * 1.5 + assert type(v2) == Quaternion + assert v2.x == 3.0 + assert v2.y == 6.75 + assert v2.z == 9.0 + assert v2.w == 13.5 + + +def test_quaternion_norm_squared(): + """Test that the squared norm of a Quaternion object can be calculated.""" + v = Quaternion(x=3.0, y=4.0, z=5.0, w=6.0) + assert v.norm_squared() == 86.0 + + +def test_quaternion_norm(): + """Test that the norm of a Quaternion object can be calculated.""" + v = Quaternion(x=3.0, y=4.0, z=5.0, w=6.0) + assert (v.norm() - 86.0**0.5) < 1e-6 + + +@pytest.mark.parametrize( + ("x", "y", "z", "w", "expected"), + [ + (0, 0, 0, 0, ZeroDivisionError), + (1, 0, 0, 0, Quaternion(x=1, y=0, z=0, w=0)), + (0, 1, 0, 0, Quaternion(x=0, y=1, z=0, w=0)), + (0, 0, 1, 0, Quaternion(x=0, y=0, z=1, w=0)), + (0, 0, 0, 1, Quaternion(x=0, y=0, z=0, w=1)), + (1, 1, 1, 1, Quaternion(x=1, y=1, z=1, w=1) / math.sqrt(4)), + (-1, -1, -1, -1, Quaternion(x=-1, y=-1, z=-1, w=-1) / math.sqrt(4)), + ], +) +def test_quaternion_normalize(x: float, y: float, z: float, w: float, expected: Quaternion | type): + """Test that a Quaternion object can be normalized.""" + v = Quaternion(x=x, y=y, z=z, w=w) + if isinstance(expected, type): + expected = cast(type[Exception], expected) + with pytest.raises(expected): + v.normalize() + else: + assert (v.normalize() - expected).norm() < 1e-6 + + +def test_quaternion_tuple(): + """Test that a Quaternion object can be converted to a tuple.""" + v = Quaternion(x=1.0, y=2.0, z=3.0, w=4.0) + assert v.to_tuple() == (1.0, 2.0, 3.0, 4.0) diff --git a/tests/mcproto/types/test_slot.py b/tests/mcproto/types/test_slot.py new file mode 100644 index 00000000..7db93de6 --- /dev/null +++ b/tests/mcproto/types/test_slot.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from mcproto.types.nbt import ByteNBT, CompoundNBT, EndNBT, IntNBT, NBTag +from mcproto.types.slot import Slot +from tests.helpers import gen_serializable_test + +gen_serializable_test( + context=globals(), + cls=Slot, + fields=[("present", bool), ("item_id", int), ("item_count", int), ("nbt", NBTag)], + test_data=[ + ((False, None, None, None), b"\x00"), + ((True, 1, 1, None), b"\x01\x01\x01\x00"), # EndNBT() is automatically added + ((True, 1, 1, EndNBT()), b"\x01\x01\x01\x00"), + ( + (True, 2, 3, CompoundNBT([IntNBT(4, "int_nbt"), ByteNBT(5, "byte_nbt")])), + b"\x01\x02\x03" + CompoundNBT([IntNBT(4, "int_nbt"), ByteNBT(5, "byte_nbt")]).serialize(), + ), + # Present but no item_id + ((True, None, 1, None), ValueError), + # Present but no item_count + ((True, 1, None, None), ValueError), + # Present but the NBT has the wrong type + ((True, 1, 1, IntNBT(1, "int_nbt")), TypeError), + # Not present but item_id is present + ((False, 1, 1, None), ValueError), + # Not present but item_count is present + ((False, None, 1, None), ValueError), + # Not present but NBT is present + ((False, None, None, CompoundNBT([])), ValueError), + ], +) diff --git a/tests/mcproto/types/test_vec3.py b/tests/mcproto/types/test_vec3.py new file mode 100644 index 00000000..01203d08 --- /dev/null +++ b/tests/mcproto/types/test_vec3.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +import struct +from typing import cast +import pytest +import math + +from mcproto.types.vec3 import Position, Vec3 +from tests.helpers import gen_serializable_test + +gen_serializable_test( + context=globals(), + cls=Position, + fields=[("x", int), ("y", int), ("z", int)], + test_data=[ + ((0, 0, 0), b"\x00\x00\x00\x00\x00\x00\x00\x00"), + ((-1, -1, -1), b"\xff\xff\xff\xff\xff\xff\xff\xff"), + # from https://wiki.vg/Protocol#Position + ( + (18357644, 831, -20882616), + bytes([0b01000110, 0b00000111, 0b01100011, 0b00101100, 0b00010101, 0b10110100, 0b10000011, 0b00111111]), + ), + # X out of bounds + ((1 << 25, 0, 0), OverflowError), + ((-(1 << 25) - 1, 0, 0), OverflowError), + # Y out of bounds + ((0, 1 << 11, 0), OverflowError), + ((0, -(1 << 11) - 1, 0), OverflowError), + # Z out of bounds + ((0, 0, 1 << 25), OverflowError), + ((0, 0, -(1 << 25) - 1), OverflowError), + ], +) + +gen_serializable_test( + context=globals(), + cls=Vec3, + fields=[("x", float), ("y", float), ("z", float)], + test_data=[ + ((0.0, 0.0, 0.0), struct.pack(">fff", 0.0, 0.0, 0.0)), + ((-1.0, -1.0, -1.0), struct.pack(">fff", -1.0, -1.0, -1.0)), + ((1.0, 2.0, 3.0), struct.pack(">fff", 1.0, 2.0, 3.0)), + ((1.5, 2.5, 3.5), struct.pack(">fff", 1.5, 2.5, 3.5)), + # Invalid values + ((1.0, 2.0, "3.0"), TypeError), + ((float("nan"), 2.0, 3.0), ValueError), + ((1.0, float("inf"), 3.0), ValueError), + ((1.0, 2.0, -float("inf")), ValueError), + ], +) + + +def test_position_addition(): + """Test that two Position objects can be added together (resuling in a new Position object).""" + p1 = Position(x=1, y=2, z=3) + p2 = Position(x=4, y=5, z=6) + p3 = p1 + p2 + assert type(p3) == Position + assert p3.x == 5 + assert p3.y == 7 + assert p3.z == 9 + + +def test_position_subtraction(): + """Test that two Position objects can be subtracted (resuling in a new Position object).""" + p1 = Position(x=1, y=2, z=3) + p2 = Position(x=2, y=4, z=6) + p3 = p2 - p1 + assert type(p3) == Position + assert p3.x == 1 + assert p3.y == 2 + assert p3.z == 3 + + +def test_position_negative(): + """Test that a Position object can be negated.""" + p1 = Position(x=1, y=2, z=3) + p2 = -p1 + assert type(p2) == Position + assert p2.x == -1 + assert p2.y == -2 + assert p2.z == -3 + + +def test_position_multiplication_int(): + """Test that a Position object can be multiplied by an integer.""" + p1 = Position(x=1, y=2, z=3) + p2 = p1 * 2 + assert p2.x == 2 + assert p2.y == 4 + assert p2.z == 6 + + +def test_position_multiplication_float(): + """Test that a Position object can be multiplied by a float.""" + p1 = Position(x=2, y=4, z=6) + p2 = p1 * 1.5 + assert type(p2) == Position + assert p2.x == 3 + assert p2.y == 6 + assert p2.z == 9 + + +def test_vec3_to_position(): + """Test that a Vec3 object can be converted to a Position object.""" + v = Vec3(x=1.5, y=2.5, z=3.5) + p = v.to_position() + assert type(p) == Position + assert p.x == 1 + assert p.y == 2 + assert p.z == 3 + + +def test_position_to_vec3(): + """Test that a Position object can be converted to a Vec3 object.""" + p = Position(x=1, y=2, z=3) + v = p.to_vec3() + assert type(v) == Vec3 + assert v.x == 1.0 + assert v.y == 2.0 + assert v.z == 3.0 + + +def test_position_to_tuple(): + """Test that a Position object can be converted to a tuple.""" + p = Position(x=1, y=2, z=3) + t = p.to_tuple() + assert type(t) == tuple + assert t == (1, 2, 3) + + +def test_vec3_addition(): + """Test that two Vec3 objects can be added together (resuling in a new Vec3 object).""" + v1 = Vec3(x=1.0, y=2.0, z=3.0) + v2 = Vec3(x=4.5, y=5.25, z=6.125) + v3 = v1 + v2 + assert type(v3) == Vec3 + assert v3.x == 5.5 + assert v3.y == 7.25 + assert v3.z == 9.125 + + +def test_vec3_subtraction(): + """Test that two Vec3 objects can be subtracted (resuling in a new Vec3 object).""" + v1 = Vec3(x=1.0, y=2.0, z=3.0) + v2 = Vec3(x=4.5, y=5.25, z=6.125) + v3 = v2 - v1 + assert type(v3) == Vec3 + assert v3.x == 3.5 + assert v3.y == 3.25 + assert v3.z == 3.125 + + +def test_vec3_negative(): + """Test that a Vec3 object can be negated.""" + v1 = Vec3(x=1.0, y=2.5, z=3.0) + v2 = -v1 + assert type(v2) == Vec3 + assert v2.x == -1.0 + assert v2.y == -2.5 + assert v2.z == -3.0 + + +def test_vec3_multiplication_int(): + """Test that a Vec3 object can be multiplied by an integer.""" + v1 = Vec3(x=1.0, y=2.25, z=3.0) + v2 = v1 * 2 + assert v2.x == 2.0 + assert v2.y == 4.5 + assert v2.z == 6.0 + + +def test_vec3_multiplication_float(): + """Test that a Vec3 object can be multiplied by a float.""" + v1 = Vec3(x=2.0, y=4.5, z=6.0) + v2 = v1 * 1.5 + assert type(v2) == Vec3 + assert v2.x == 3.0 + assert v2.y == 6.75 + assert v2.z == 9.0 + + +def test_vec3_norm_squared(): + """Test that the squared norm of a Vec3 object can be calculated.""" + v = Vec3(x=3.0, y=4.0, z=5.0) + assert v.norm_squared() == 50.0 + + +def test_vec3_norm(): + """Test that the norm of a Vec3 object can be calculated.""" + v = Vec3(x=3.0, y=4.0, z=5.0) + assert (v.norm() - 50.0**0.5) < 1e-6 + + +@pytest.mark.parametrize( + ("x", "y", "z", "expected"), + [ + (0, 0, 0, ZeroDivisionError), + (1, 0, 0, Vec3(x=1, y=0, z=0)), + (0, 1, 0, Vec3(x=0, y=1, z=0)), + (0, 0, 1, Vec3(x=0, y=0, z=1)), + (1, 1, 1, Vec3(x=1, y=1, z=1) / math.sqrt(3)), + (-1, -1, -1, Vec3(x=-1, y=-1, z=-1) / math.sqrt(3)), + ], +) +def test_vec3_normalize(x: float, y: float, z: float, expected: Vec3 | type): + """Test that a Vec3 object can be normalized.""" + v = Vec3(x=x, y=y, z=z) + if isinstance(expected, type): + expected = cast(type[Exception], expected) + with pytest.raises(expected): + v.normalize() + else: + assert (v.normalize() - expected).norm() < 1e-6