From f2019220055fd133bdeeee4b61523f81d9c17e62 Mon Sep 17 00:00:00 2001 From: Alexis Rossfelder Date: Sat, 27 Apr 2024 23:42:15 +0200 Subject: [PATCH 1/3] Added NBT type - Reading from Buffer object - Writing to Buffer object - Converting from/to python objects 100% test coverage for the nbt.py, Pyright and Ruff are happy --- changes/257.feature.md | 9 + mcproto/types/nbt.py | 1315 +++++++++++++++++++++++++++++++ tests/mcproto/types/test_nbt.py | 1073 +++++++++++++++++++++++++ 3 files changed, 2397 insertions(+) create mode 100644 changes/257.feature.md create mode 100644 mcproto/types/nbt.py create mode 100644 tests/mcproto/types/test_nbt.py diff --git a/changes/257.feature.md b/changes/257.feature.md new file mode 100644 index 00000000..32fefeab --- /dev/null +++ b/changes/257.feature.md @@ -0,0 +1,9 @@ +- Added the `NBTag` to deal with NBT data: + - The `NBTag` class is the base class for all NBT tags and provides the basic functionality to serialize and deserialize NBT data from and to a `Buffer` object. + - The classes `EndNBT`, `ByteNBT`, `ShortNBT`, `IntNBT`, `LongNBT`, `FloatNBT`, `DoubleNBT`, `ByteArrayNBT`, `StringNBT`, `ListNBT`, `CompoundNBT`, `IntArrayNBT`and `LongArrayNBT` were added and correspond to the NBT types described in the [NBT specification](https://wiki.vg/NBT#Specification). + - NBT tags can be created using the `NBTag.from_object()` method, which automatically selects the correct tag type based on the object's type and works recursively for lists and dictionaries. + - The `NBTag.to_object()` method can be used to convert an NBT tag back to a Python object. + - The `NBTag.serialize()` can be used to serialize an NBT tag to a new `Buffer` object. + - The `NBTag.deserialize(buffer)` can be used to deserialize an NBT tag from a `Buffer` object. + - If the buffer already exists, the `NBTag.write_to(buffer, with_type=True, with_name=True)` method can be used to write the NBT tag to the buffer (and in that case with the type and name in the right format). + - The `NBTag.read_from(buffer, with_type=True, with_name=True)` method can be used to read an NBT tag from the buffer (and in that case with the type and name in the right format). diff --git a/mcproto/types/nbt.py b/mcproto/types/nbt.py new file mode 100644 index 00000000..e8a057bb --- /dev/null +++ b/mcproto/types/nbt.py @@ -0,0 +1,1315 @@ +from __future__ import annotations + +import warnings +from abc import ABCMeta +from enum import IntEnum +from typing import ClassVar, List, Mapping, Union, cast + +from typing_extensions import TypeAlias + +from mcproto.buffer import Buffer +from mcproto.protocol.base_io import StructFormat +from mcproto.types.abc import MCType + +__all__ = [ + "NBTagType", + "NBTag", + "EndNBT", + "ByteNBT", + "ShortNBT", + "IntNBT", + "LongNBT", + "FloatNBT", + "DoubleNBT", + "ByteArrayNBT", + "StringNBT", + "ListNBT", + "CompoundNBT", + "IntArrayNBT", + "LongArrayNBT", +] + + +# region NBT Specification +""" +Source : https://web.archive.org/web/20110723210920/http://www.minecraft.net/docs/NBT.txt + +Named Binary Tag specification + +NBT (Named Binary Tag) is a tag based binary format designed to carry large amounts of binary data with smaller amounts +of additional data. +An NBT file consists of a single GZIPped Named Tag of type TAG_Compound. + +A Named Tag has the following format: + + byte tagType + TAG_String name + [payload] + +The tagType is a single byte defining the contents of the payload of the tag. + +The name is a descriptive name, and can be anything (eg "cat", "banana", "Hello World!"). It has nothing to do with the +tagType. +The purpose for this name is to name tags so parsing is easier and can be made to only look for certain recognized tag +names. +Exception: If tagType is TAG_End, the name is skipped and assumed to be "". + +The [payload] varies by tagType. + +Note that ONLY Named Tags carry the name and tagType data. Explicitly identified Tags (such as TAG_String above) only +contains the payload. + + +The tag types and respective payloads are: + + TYPE: 0 NAME: TAG_End + Payload: None. + Note: This tag is used to mark the end of a list. + Cannot be named! If type 0 appears where a Named Tag is expected, the name is assumed to be "". + (In other words, this Tag is always just a single 0 byte when named, and nothing in all other cases) + + TYPE: 1 NAME: TAG_Byte + Payload: A single signed byte (8 bits) + + TYPE: 2 NAME: TAG_Short + Payload: A signed short (16 bits, big endian) + + TYPE: 3 NAME: TAG_Int + Payload: A signed short (32 bits, big endian) + + TYPE: 4 NAME: TAG_Long + Payload: A signed long (64 bits, big endian) + + TYPE: 5 NAME: TAG_Float + Payload: A floating point value (32 bits, big endian, IEEE 754-2008, binary32) + + TYPE: 6 NAME: TAG_Double + Payload: A floating point value (64 bits, big endian, IEEE 754-2008, binary64) + + TYPE: 7 NAME: TAG_Byte_Array + Payload: TAG_Int length + An array of bytes of unspecified format. The length of this array is bytes + + TYPE: 8 NAME: TAG_String + Payload: TAG_Short length + An array of bytes defining a string in UTF-8 format. The length of this array is bytes + + TYPE: 9 NAME: TAG_List + Payload: TAG_Byte tagId + TAG_Int length + A sequential list of Tags (not Named Tags), of type . The length of this array is Tags + Notes: All tags share the same type. + + TYPE: 10 NAME: TAG_Compound + Payload: A sequential list of Named Tags. This array keeps going until a TAG_End is found. + TAG_End end + Notes: If there's a nested TAG_Compound within this tag, that one will also have a TAG_End, so simply reading + until the next TAG_End will not work. + The names of the named tags have to be unique within each TAG_Compound + The order of the tags is not guaranteed. + + + // NEW TAGS + TYPE: 11 NAME: TAG_Int_Array + Payload: TAG_Int length + An array of integers. The length of this array is integers + + TYPE: 12 NAME: TAG_Long_Array + Payload: TAG_Int length + An array of longs. The length of this array is longs + + +""" +# endregion +# region NBT base classes/types + + +class NBTagType(IntEnum): + """Types of NBT tags.""" + + END = 0 + BYTE = 1 + SHORT = 2 + INT = 3 + LONG = 4 + FLOAT = 5 + DOUBLE = 6 + BYTE_ARRAY = 7 + STRING = 8 + LIST = 9 + COMPOUND = 10 + INT_ARRAY = 11 + LONG_ARRAY = 12 + + +PayloadType: TypeAlias = Union[ + int, + float, + bytearray, + bytes, + str, + List["PayloadType"], + Mapping[str, "PayloadType"], + List[int], + "NBTag", + List["NBTag"], +] + + +class _MetaNBTag(ABCMeta): + """Metaclass for NBT tags.""" + + TYPE: NBTagType = NBTagType.COMPOUND + + def __new__(cls, name: str, bases: tuple[type], namespace: dict, **kwargs): + new_cls: NBTag = super().__new__(cls, name, bases, namespace) # type: ignore + if name != "NBTag": + NBTag.ASSOCIATED_TYPES[new_cls.TYPE] = new_cls # type: ignore + return new_cls + + +class NBTag(MCType, metaclass=_MetaNBTag): + """Base class for NBT tags.""" + + __slots__ = ("name", "payload") + + TYPE: ClassVar[NBTagType] = NBTagType.COMPOUND + + ASSOCIATED_TYPES: ClassVar[dict[NBTagType, type[NBTag]]] = {} + + def __init__(self, payload: PayloadType, name: str = ""): + if self.__class__ == NBTag: + raise TypeError("Cannot instantiate an NBTag object directly, use a subclass instead.") + self.name = name + self.payload = payload + + def serialize(self, with_type: bool = True, with_name: bool = True) -> Buffer: + """Serialize the NBT tag to a buffer. + + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: These parameters only control the first level of serialization. + :return: The buffer containing the serialized NBT tag. + """ + buf = Buffer() + self.write_to(buf, with_name=with_name, with_type=with_type) + return buf + + def _write_header(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> bool: + if with_type: + buf.write_value(StructFormat.BYTE, self.TYPE.value) + if self.TYPE == NBTagType.END: + return False + if with_name: + if not self.name: + raise ValueError("Named tags must have a name.") + StringNBT(self.name).write_to(buf, with_type=False, with_name=False) + return True + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the NBT tag to the buffer.""" + ... + + @classmethod + def deserialize(cls, buf: Buffer, with_name: bool = True, with_type: bool = True) -> NBTag: + """Deserialize the NBT tag. + + :param buf: The buffer to read from. + :param with_name: Whether to read the name of the tag. If set to False, the tag will have the name "". + :param with_type: Whether to read the type of the tag. + If set to True, the tag will be read from the buffer and + the return type will be inferred from the tag type. + If set to False and called from a subclass, the tag type will be inferred from the subclass. + If set to False and called from the base class, the tag type will be TAG_Compound. + + If with_type is set to False, the buffer must not start with the tag type byte. + + :return: The deserialized NBT tag. + """ + name, tag_type = cls._read_header(buf, with_name=with_name, read_type=with_type) + + tag_class = NBTag.ASSOCIATED_TYPES[tag_type] + if cls not in (NBTag, tag_class): + raise TypeError(f"Expected a {cls.__name__} tag, but found a different tag ({tag_class.__name__}).") + + tag = tag_class.read_from(buf, with_type=False, with_name=False) + tag.name = name + return tag + + @classmethod + def _read_header(cls, buf: Buffer, read_type: bool = True, with_name: bool = True) -> tuple[str, NBTagType]: + """Read the header of the NBT tag. + + :param buf: The buffer to read from. + :param read_type: Whether to read the type of the tag from the buffer. + :param with_name: Whether to read the name of the tag. If set to False, the tag will have the name "". + + :return: A tuple containing the name and the tag type. + + + :note: It is possible that this function reads nothing from the buffer if both with_name and read_type are set + to False. + """ + tag_type: NBTagType = cls.TYPE # default value + if read_type: + try: + tag_type = NBTagType(buf.read_value(StructFormat.BYTE)) + except OSError: + raise IOError("Buffer is empty.") from None + except ValueError: + raise TypeError("Invalid tag type.") from None + + if tag_type == NBTagType.END: + return "", tag_type + + name = StringNBT.read_from(buf, with_type=False, with_name=False).value if with_name else "" + + return name, tag_type + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> NBTag: + """Read the NBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag. + If set to True, the tag will be read from the buffer and + the return type will be inferred from the tag type. + If set to False and called from a subclass, the tag type will be inferred from the subclass. + If set to False and called from the base class, the tag type will be TAG_Compound. + :param with_name: Whether to read the name of the tag. If set to False, the tag will have the name "". + + :return: The NBT tag. + """ + return cls.deserialize(buf, with_name=with_name, with_type=with_type) + + @staticmethod + def from_object(data: object, /, name: str = "", *, use_int_array: bool = True) -> NBTag: # noqa: PLR0911,PLR0912 + """Create an NBT tag from an arbitrary (compatible) Python object. + + :param data: The object to convert to an NBT tag. + :param use_int_array: Whether to use IntArrayNBT and LongArrayNBT for lists of integers. + If set to False, all lists of integers will be considered as ListNBT. + :param name: The name of the resulting tag. Used for recursive calls. + + :return: The NBT tag representing the object. + + :note: The function will attempt to convert the object to an NBT tag in the following way: + - If the object is a dictionary with a single key, the key will be used as the name of the tag. + - If the object is an integer, it will be converted to a ByteNBT, ShortNBT, IntNBT, or LongNBT tag + depending on the value. + - If the object is a list, it will be converted to a ListNBT tag. + - If the object is a dictionary, it will be converted to a CompoundNBT tag. + - If the object is a string, it will be converted to a StringNBT tag. + - If the object is a float, it will be converted to a FloatNBT tag. + - If the object can be serialized to bytes, it will be converted to a ByteArrayNBT tag. + + + - If you want an object to be serialized in a specific way, you can implement: + + ```python + def to_nbt(self, name: str = "") -> NBTag: + ... + ``` + """ + if hasattr(data, "to_nbt"): # For objects that can be converted to NBT + return data.to_nbt(name=name) # type: ignore + + if isinstance(data, int): + if -(1 << 7) <= data < 1 << 7: + return ByteNBT(data, name=name) + if -(1 << 15) <= data < 1 << 15: + return ShortNBT(data, name=name) + if -(1 << 31) <= data < 1 << 31: + return IntNBT(data, name=name) + if -(1 << 63) <= data < 1 << 63: + return LongNBT(data, name=name) + raise ValueError(f"Integer {data} is out of range.") + if isinstance(data, float): + return FloatNBT(data, name=name) + if isinstance(data, str): + return StringNBT(data, name=name) + if isinstance(data, (bytearray, bytes)): + if isinstance(data, bytearray): + data = bytes(data) + return ByteArrayNBT(data, name=name) + if isinstance(data, list): + if not data: + # Type END is used to mark an empty list + return ListNBT([], name=name) + first_type = type(data[0]) + if any(type(item) != first_type for item in data): + raise TypeError("All items in a list must be of the same type.") + + if issubclass(first_type, int) and use_int_array: + # Check the range of the integers in the list + use_int = all(-(1 << 31) <= item < 1 << 31 for item in data) + use_long = all(-(1 << 63) <= item < 1 << 63 for item in data) + if use_int: + return IntArrayNBT(data, name=name) + if not use_long: # Too big to fit in a long, won't fit in a List of Longs either + raise ValueError("Integer list contains values out of range.") + return LongArrayNBT(data, name=name) + return ListNBT([NBTag.from_object(item, use_int_array=use_int_array) for item in data], name=name) + if isinstance(data, dict): + if len(data) == 0: + return CompoundNBT([], name=name) + if len(data) == 1 and name == "": + key, value = next(iter(data.items())) + return NBTag.from_object(value, name=key, use_int_array=use_int_array) + payload = [] + for key, value in data.items(): + tag = NBTag.from_object(value, name=key, use_int_array=use_int_array) + payload.append(tag) + return CompoundNBT(payload, name) + if data is None: + warnings.warn("Converting None to an END tag.", stacklevel=2) + return EndNBT() # Should not be used + + try: + # Check if the object can be converted to bytes + return ByteArrayNBT(bytes(data), name=name) # type: ignore + except (TypeError, ValueError): + pass + raise TypeError(f"Cannot convert object of type {type(data)} to an NBT tag.") + + def to_object(self) -> Mapping[str, PayloadType] | PayloadType: + """Convert the NBT payload to a dictionary.""" + return CompoundNBT(self.payload).to_object() # allow NBTag.to_object to act as a dict + + def __getitem__(self, key: str | int) -> PayloadType: + """Get a tag from the list or compound tag.""" + if self.TYPE not in (NBTagType.LIST, NBTagType.COMPOUND, NBTagType.INT_ARRAY, NBTagType.LONG_ARRAY): + raise TypeError(f"Cannot get a tag by index from a non-LIST or non-COMPOUND tag ({self.TYPE}).") + + if not isinstance(self.payload, list): + raise AttributeError( + f"The payload of the tag is not a list ({self.TYPE}).\n" + "Check that the initialization of the tag is correct." + ) + if not isinstance(key, (str, int)): # type: ignore + raise TypeError("Key must be a string or an integer.") + + if isinstance(key, str): + if self.TYPE != NBTagType.COMPOUND: + raise TypeError(f"Cannot get a tag by name from a non-COMPOUND tag ({self.TYPE}).") + if not all(isinstance(tag, NBTag) for tag in self.payload): + raise AttributeError("The payload of the tag is not a list of NBTag objects.") + for tag in self.payload: + tag = cast(NBTag, tag) + if tag.name == key: + return tag + raise KeyError(f"No tag with the name {key!r} found.") + + # Key is an integer + if key < -len(self.payload) or key >= len(self.payload): + raise IndexError(f"Index {key} out of range.") + return self.payload[key] + + def __repr__(self) -> str: + if self.name: + return f"{self.__class__.__name__}[{self.name!r}]({self.payload!r})" + return f"{self.__class__.__name__}({self.payload!r})" + + def __eq__(self, other: object) -> bool: + """Check equality between two NBT tags.""" + if not isinstance(other, NBTag): + raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") + return self.name == other.name and self.TYPE == other.TYPE and self.payload == other.payload + + def to_nbt(self, name: str = "") -> NBTag: + """Convert the object to an NBT tag. + + ..warning This is already an NBT tag, so it will modify the name of the tag and return itself. + """ + self.name = name + return self + + @property + def value(self) -> PayloadType: + """Get the payload of the NBT tag in a python-friendly format.""" + obj = self.to_object() + if isinstance(obj, dict) and self.name: + return obj[self.name] + return obj + + +# endregion +# region NBT tags types + + +class EndNBT(NBTag): + """Sentinel tag used to mark the end of a TAG_Compound.""" + + TYPE = NBTagType.END + __slots__ = () + + def __init__(self): + """Create a new EndNBT tag.""" + super().__init__(0, name="") + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the EndNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> EndNBT: + """Read the EndNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag. Has no effect on the EndNBT tag. + + :return: The EndNBT tag. + """ + _, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + return EndNBT() + + def to_object(self) -> Mapping[str, PayloadType]: + """Convert the EndNBT tag to a python object. + + :return: An empty dictionary. + """ + return {} + + +class ByteNBT(NBTag): + """NBT tag representing a single byte value, represented as a signed 8-bit integer.""" + + TYPE = NBTagType.BYTE + + __slots__ = () + payload: int + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the ByteNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + if self.payload < -(1 << 7) or self.payload >= 1 << 7: + raise OverflowError("Byte value out of range.") + + buf.write_value(StructFormat.BYTE, self.payload) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteNBT: + """Read the ByteNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The ByteNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 1: + raise IOError("Buffer does not contain enough data to read a byte. (Empty buffer)") + + return ByteNBT(buf.read_value(StructFormat.BYTE), name=name) + + def __int__(self) -> int: + """Get the integer value of the ByteNBT tag.""" + return self.payload + + def to_object(self) -> Mapping[str, int] | int: + """Convert the ByteNBT tag to a python object. + + :return: A dictionary containing the name and the integer value of the tag. If the tag has no name, the value + will be returned directly. + """ + if self.name: + return {self.name: self.payload} + return self.payload + + @property + def value(self) -> int: + """Get the integer value of the IntNBT tag.""" + return self.payload + + +class ShortNBT(ByteNBT): + """NBT tag representing a short value, represented as a signed 16-bit integer.""" + + TYPE = NBTagType.SHORT + + __slots__ = () + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the ShortNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The short value is written as a signed 16-bit integer in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + + if self.payload < -(1 << 15) or self.payload >= 1 << 15: + raise OverflowError("Short value out of range.") + + buf.write(self.payload.to_bytes(2, "big", signed=True)) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ShortNBT: + """Read the ShortNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The ShortNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 2: + raise IOError("Buffer does not contain enough data to read a short.") + + return ShortNBT(int.from_bytes(buf.read(2), "big", signed=True), name=name) + + +class IntNBT(ByteNBT): + """NBT tag representing an integer value, represented as a signed 32-bit integer.""" + + TYPE = NBTagType.INT + + __slots__ = () + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the IntNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The integer value is written as a signed 32-bit integer in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + + if self.payload < -(1 << 31) or self.payload >= 1 << 31: + raise OverflowError("Integer value out of range.") + + # No more messing around with the struct, we want 32 bits of data no matter what + buf.write(self.payload.to_bytes(4, "big", signed=True)) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntNBT: + """Read the IntNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The IntNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 4: + raise IOError("Buffer does not contain enough data to read an int.") + + return IntNBT(int.from_bytes(buf.read(4), "big", signed=True), name=name) + + +class LongNBT(ByteNBT): + """NBT tag representing a long value, represented as a signed 64-bit integer.""" + + TYPE = NBTagType.LONG + + __slots__ = () + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the LongNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The long value is written as a signed 64-bit integer in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + + if self.payload < -(1 << 63) or self.payload >= 1 << 63: + raise OverflowError("Long value out of range.") + + # No more messing around with the struct, we want 64 bits of data no matter what + buf.write(self.payload.to_bytes(8, "big", signed=True)) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongNBT: + """Read the LongNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The LongNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 8: + raise IOError("Buffer does not contain enough data to read a long.") + + payload = int.from_bytes(buf.read(8), "big", signed=True) + return LongNBT(payload, name=name) + + +class FloatNBT(NBTag): + """NBT tag representing a floating-point value, represented as a 32-bit IEEE 754-2008 binary32 value.""" + + TYPE = NBTagType.FLOAT + + payload: float + + __slots__ = () + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the FloatNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The float value is written as a 32-bit floating-point value in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + buf.write_value(StructFormat.FLOAT, self.payload) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> FloatNBT: + """Read the FloatNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The FloatNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 4: + raise IOError("Buffer does not contain enough data to read a float.") + + return FloatNBT(buf.read_value(StructFormat.FLOAT), name=name) + + def __float__(self) -> float: + """Get the float value of the FloatNBT tag.""" + return self.payload + + def to_object(self) -> Mapping[str, float] | float: + """Convert the FloatNBT tag to a python object. + + :return: A dictionary containing the name and the float value of the tag. If the tag has no name, the value + will be returned directly. + """ + if self.name: + return {self.name: self.payload} + return self.payload + + def __eq__(self, other: object) -> bool: + """Check equality between two FloatNBT tags. + + :param other: The other FloatNBT tag to compare to. + + :return: True if the tags are equal, False otherwise. + + :note: The float values are compared with a small epsilon (1e-6) to account for floating-point errors. + """ + if not isinstance(other, NBTag): + raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") + # Compare the float values with a small epsilon + if not (self.name == other.name and self.TYPE == other.TYPE): + return False + if not isinstance(other, self.__class__): # pragma: no cover + return False # Should not happen if nobody messes with the TYPE attribute + + return abs(self.payload - other.payload) < 1e-6 + + @property + def value(self) -> float: + """Get the float value of the FloatNBT tag.""" + return self.payload + + +class DoubleNBT(FloatNBT): + """NBT tag representing a double-precision floating-point value, represented as a 64-bit IEEE 754-2008 binary64.""" + + TYPE = NBTagType.DOUBLE + + __slots__ = () + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the DoubleNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The double value is written as a 64-bit floating-point value in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + buf.write_value(StructFormat.DOUBLE, self.payload) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> DoubleNBT: + """Read the DoubleNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The DoubleNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 8: + raise IOError("Buffer does not contain enough data to read a double.") + + return DoubleNBT(buf.read_value(StructFormat.DOUBLE), name=name) + + +class ByteArrayNBT(NBTag): + """NBT tag representing an array of bytes. The length of the array is stored as a signed 32-bit integer.""" + + TYPE = NBTagType.BYTE_ARRAY + + __slots__ = () + + payload: bytearray + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the ByteArrayNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The length of the byte array is written as a signed 32-bit integer in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + IntNBT(len(self.payload)).write_to(buf, with_type=False, with_name=False) + buf.write(self.payload) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteArrayNBT: + """Read the ByteArrayNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The ByteArrayNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + try: + length = IntNBT.read_from(buf, with_type=False, with_name=False).value + except IOError: + raise IOError("Buffer does not contain enough data to read a byte array.") from None + + if length < 0: + raise ValueError("Invalid byte array length.") + + if buf.remaining < length: + raise IOError( + f"Buffer does not contain enough data to read the byte array ({buf.remaining} < {length} bytes)." + ) + + return ByteArrayNBT(buf.read(length), name=name) + + def __bytes__(self) -> bytes: + """Get the bytes value of the ByteArrayNBT tag.""" + return self.payload + + def to_object(self) -> Mapping[str, bytearray] | bytearray: + """Convert the ByteArrayNBT tag to a python object. + + :return: A dictionary containing the name and the byte array value of the tag. If the tag has no name, the + value will be returned directly. + """ + if self.name: + return {self.name: self.payload} + return self.payload + + def __repr__(self) -> str: + """Get a string representation of the ByteArrayNBT tag.""" + if self.name: + return f"{self.__class__.__name__}[{self.name!r}](length={len(self.payload)})" + if len(self.payload) < 8: + return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{self.__class__.__name__}(length={len(self.payload)}, {bytes(self.payload[:7])!r}...)" + + @property + def value(self) -> bytearray: + """Get the bytes value of the ByteArrayNBT tag.""" + return self.payload + + +class StringNBT(NBTag): + """NBT tag representing an UTF-8 string value. The length of the string is stored as a signed 16-bit integer.""" + + TYPE = NBTagType.STRING + + __slots__ = () + + payload: str + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the StringNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The length of the string is written as a signed 16-bit integer in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + if len(self.payload) > 32767: + # Check the length of the string (can't generate strings that long in tests) + raise ValueError("Maximum character limit for writing strings is 32767 characters.") # pragma: no cover + + data = bytearray(self.payload, "utf-8") + ShortNBT(len(data)).write_to(buf, with_type=False, with_name=False) + buf.write(data) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> StringNBT: + """Read the StringNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The StringNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + try: + length = ShortNBT.read_from(buf, with_type=False, with_name=False).value + except IOError: + raise IOError("Buffer does not contain enough data to read a string.") from None + + if length < 0: + raise ValueError("Invalid string length.") + + if buf.remaining < length: + raise IOError("Buffer does not contain enough data to read the string.") + data = buf.read(length) + try: + return StringNBT(data.decode("utf-8"), name=name) + except UnicodeDecodeError: + raise # We want to know it + + def __str__(self) -> str: + """Get the string value of the StringNBT tag.""" + return self.payload + + def to_object(self) -> Mapping[str, str] | str: + """Convert the StringNBT tag to a python object. + + :return: A dictionary containing the name and the string value of the tag. If the tag has no name, the value + will be returned directly. + """ + if self.name: + return {self.name: self.payload} + return self.payload + + @property + def value(self) -> str: + """Get the string value of the StringNBT tag.""" + return self.payload + + +class ListNBT(NBTag): + """NBT tag representing a list of tags. All tags in the list must be of the same type.""" + + TYPE = NBTagType.LIST + + __slots__ = () + + payload: list[NBTag] + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the ListNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The tag type of the list is written as a single byte, followed by the length of the list as a signed + 32-bit integer in big-endian format. The tags in the list are then serialized one by one. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + + if not self.payload: + # Set the tag type to TAG_End if the list is empty + EndNBT().write_to(buf, with_name=False) + IntNBT(0).write_to(buf, with_name=False, with_type=False) + return + + if not all(isinstance(tag, NBTag) for tag in self.payload): # type: ignore # We want to check anyway + raise ValueError( + f"All items in a list must be NBTags. Got {self.payload!r}.\nUse NBTag.from_object() to convert " + "objects to tags first." + ) + + tag_type = self.payload[0].TYPE + ByteNBT(tag_type).write_to(buf, with_name=False, with_type=False) + IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False) + for tag in self.payload: + if tag_type != tag.TYPE: + raise ValueError(f"All tags in a list must be of the same type, got tag {tag!r}") + if tag.name != "": + raise ValueError(f"All tags in a list must be unnamed, got tag {tag!r}") + + tag.write_to(buf, with_type=False, with_name=False) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ListNBT: + """Read the ListNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The ListNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + list_tag_type = ByteNBT.read_from(buf, with_type=False, with_name=False).payload + try: + length = IntNBT.read_from(buf, with_type=False, with_name=False).value + except IOError: + raise IOError("Buffer does not contain enough data to read a list.") from None + + if length < 0 or list_tag_type == NBTagType.END: + return ListNBT([], name=name) + + try: + list_tag_type = NBTagType(list_tag_type) + except ValueError: + raise TypeError(f"Unknown tag type {list_tag_type}.") from None + + list_type_class = NBTag.ASSOCIATED_TYPES.get(list_tag_type, NBTag) + if list_type_class == NBTag: + raise TypeError(f"Unknown tag type {list_tag_type}.") # pragma: no cover + try: + payload = [ + # The type is already known, so we don't need to read it again + # List items are unnamed, so we don't need to read the name + list_type_class.read_from(buf, with_type=False, with_name=False) + for _ in range(length) + ] + except IOError: + raise IOError("Buffer does not contain enough data to read the list.") from None + return ListNBT(payload, name=name) + + def __iter__(self): + """Iterate over the tags in the list.""" + yield from self.payload + + def __repr__(self) -> str: + """Get a string representation of the ListNBT tag.""" + if self.name: + return f"{self.__class__.__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" + if len(self.payload) < 8: + return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" + + def to_object(self) -> Mapping[str, list[PayloadType]] | list[PayloadType]: + """Convert the ListNBT tag to a python object. + + :return: A dictionary containing the name and the list of tags. If the tag has no name, the list will be + returned directly. + """ + self.payload: list[NBTag] + if self.name: + return {self.name: [tag.to_object() for tag in self.payload]} # Extract the (unnamed) object from each tag + return [tag.to_object() for tag in self.payload] # Extract the (unnamed) object from each tag + + +class CompoundNBT(NBTag): + """NBT tag representing a compound of named tags.""" + + TYPE = NBTagType.COMPOUND + + __slots__ = () + + payload: list[NBTag] + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the CompoundNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. THis only affects the name of + the compound tag itself, not the names of the tags inside the compound. + + :note: The tags in the compound are serialized one by one, followed by an EndNBT tag. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + if not self.payload: + EndNBT().write_to(buf, with_name=False, with_type=True) + return + if not all(isinstance(tag, NBTag) for tag in self.payload): # type: ignore # We want to check anyway + raise ValueError( + f"All items in a compound must be NBTags. Got {self.payload!r}.\n" + "Use NBTag.from_object() to convert objects to tags first." + ) + + if not all(tag.name for tag in self.payload): + raise ValueError(f"All tags in a compound must be named, got tags {self.payload!r}") + + if len(self.payload) != len({tag.name for tag in self.payload}): # Check for duplicate names + raise ValueError("All tags in a compound must have unique names.") + + for tag in self.payload: + tag.write_to(buf) + EndNBT().write_to(buf, with_name=False, with_type=True) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> CompoundNBT: + """Read the CompoundNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The CompoundNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + payload = [] + while True: + child_name, child_type = cls._read_header(buf, with_name=True, read_type=True) + if child_type == NBTagType.END: + break + # The name and type of the tag have already been read + tag = NBTag.ASSOCIATED_TYPES[child_type].read_from(buf, with_type=False, with_name=False) + tag.name = child_name + payload.append(tag) + return CompoundNBT(payload, name=name) + + def __iter__(self): + """Iterate over the tags in the compound.""" + for tag in self.payload: + yield tag.name, tag + + def __repr__(self) -> str: + """Get a string representation of the CompoundNBT tag.""" + if self.name: + return f"{self.__class__.__name__}[{self.name!r}]({dict(self)})" + return f"{self.__class__.__name__}({dict(self)})" + + def to_object(self) -> Mapping[str, Mapping[str, PayloadType]]: + """Convert the CompoundNBT tag to a python object. + + :return: A dictionary containing the name and the dictionary of tags. If the tag has no name, the dictionary + will be returned directly. + """ + result = {} + for tag in self.payload: + if tag.name in result: + raise ValueError(f"Duplicate tag name {tag.name!r} in the compound.") + if tag.name == "": + raise ValueError("All tags in a compound must have a name.") + result.update(cast("dict[str, PayloadType]", tag.to_object())) + if self.name: + return {self.name: result} + return result + + def __eq__(self, other: object) -> bool: + """Check equality between two CompoundNBT tags. + + :param other: The other CompoundNBT tag to compare to. + + :return: True if the tags are equal, False otherwise. + + :note: The order of the tags is not guaranteed, but the names of the tags must match. This function assumes + that there are no duplicate tags in the compound. + """ + # The order of the tags is not guaranteed + if not isinstance(other, NBTag): + raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") + if self.name != other.name or self.TYPE != other.TYPE: + return False + if not isinstance(other, self.__class__): # pragma: no cover + return False # Should not happen if nobody messes with the TYPE attribute + if len(self.payload) != len(other.payload): + return False + return all(tag in other.payload for tag in self.payload) + + +class IntArrayNBT(NBTag): + """NBT tag representing an array of integers. The length of the array is stored as a signed 32-bit integer.""" + + TYPE = NBTagType.INT_ARRAY + + __slots__ = () + + payload: list[int] + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the IntArrayNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The length of the integer array is written as a signed 32-bit integer in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + + if any(not isinstance(item, int) for item in self.payload): # type: ignore # We want to check anyway + raise ValueError("All items in an integer array must be integers.") + + if any(item < -(1 << 31) or item >= 1 << 31 for item in self.payload): + raise OverflowError("Integer array contains values out of range.") + + IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False) + for i in self.payload: + IntNBT(i).write_to(buf, with_name=False, with_type=False) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntArrayNBT: + """Read the IntArrayNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The IntArrayNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != NBTagType.INT_ARRAY: + raise TypeError(f"Expected an INT_ARRAY tag, but found a different tag ({tag_type}).") + length = IntNBT.read_from(buf, with_type=False, with_name=False).value + try: + payload = [IntNBT.read_from(buf, with_type == NBTagType.INT, with_name=False).value for _ in range(length)] + except IOError: + raise IOError( + "Buffer does not contain enough data to read the entire integer array. (Incomplete data)" + ) from None + return IntArrayNBT(payload, name=name) + + def __repr__(self) -> str: + """Get a string representation of the IntArrayNBT tag.""" + if self.name: + return f"{self.__class__.__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" + if len(self.payload) < 8: + return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" + + def __iter__(self): + """Iterate over the integers in the array.""" + yield from self.payload + + def to_object(self) -> Mapping[str, list[int]] | list[int]: + """Convert the IntArrayNBT tag to a python object. + + :return: A dictionary containing the name and the list of integers. If the tag has no name, the list will be + returned directly. + """ + if self.name: + return {self.name: self.payload} + return self.payload + + @property + def value(self) -> list[int]: + """Get the list of integers in the IntArrayNBT tag.""" + return self.payload + + +class LongArrayNBT(IntArrayNBT): + """NBT tag representing an array of longs. The length of the array is stored as a signed 32-bit integer.""" + + TYPE = NBTagType.LONG_ARRAY + + __slots__ = () + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the LongArrayNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The length of the long array is written as a signed 32-bit integer in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + + if any(not isinstance(item, int) for item in self.payload): # type: ignore # We want to check anyway + raise ValueError(f"All items in a long array must be integers. ({self.payload})") + + if any(item < -(1 << 63) or item >= 1 << 63 for item in self.payload): + raise OverflowError(f"Long array contains values out of range. ({self.payload})") + + IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False) + for i in self.payload: + LongNBT(i).write_to(buf, with_name=False, with_type=False) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongArrayNBT: + """Read the LongArrayNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The LongArrayNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != NBTagType.LONG_ARRAY: + raise TypeError(f"Expected a LONG_ARRAY tag, but found a different tag ({tag_type}).") + length = IntNBT.read_from(buf, with_type=False, with_name=False).payload + + try: + payload = [LongNBT.read_from(buf, with_type=False, with_name=False).payload for _ in range(length)] + except IOError: + raise IOError( + "Buffer does not contain enough data to read the entire long array. (Incomplete data)" + ) from None + return LongArrayNBT(payload, name=name) + + +# endregion diff --git a/tests/mcproto/types/test_nbt.py b/tests/mcproto/types/test_nbt.py new file mode 100644 index 00000000..18e5b37c --- /dev/null +++ b/tests/mcproto/types/test_nbt.py @@ -0,0 +1,1073 @@ +from __future__ import annotations + +import struct + +import pytest + +from mcproto.buffer import Buffer +from mcproto.types.nbt import ( + ByteArrayNBT, + ByteNBT, + CompoundNBT, + DoubleNBT, + EndNBT, + FloatNBT, + IntArrayNBT, + IntNBT, + ListNBT, + LongArrayNBT, + LongNBT, + NBTag, + NBTagType, + PayloadType, + ShortNBT, + StringNBT, +) + +# region EndNBT + + +def test_serialize_deserialize_end(): + """Test serialization/deserialization of NBT END tag.""" + output_bytes = EndNBT().serialize() + assert output_bytes == bytearray.fromhex("00") + + buffer = Buffer() + EndNBT().write_to(buffer) + assert buffer == bytearray.fromhex("00") + + buffer.clear() + EndNBT().write_to(buffer, with_name=False) + assert buffer == bytearray.fromhex("00") + + buffer = Buffer(bytearray.fromhex("00")) + assert NBTag.deserialize(buffer).TYPE == NBTagType.END + + +# endregion +# region Numerical NBT tests + + +@pytest.mark.parametrize( + ("nbt_class", "value", "expected_bytes"), + [ + (ByteNBT, 0, bytearray.fromhex("01 00")), + (ByteNBT, 1, bytearray.fromhex("01 01")), + (ByteNBT, 127, bytearray.fromhex("01 7F")), + (ByteNBT, -128, bytearray.fromhex("01 80")), + (ByteNBT, -1, bytearray.fromhex("01 FF")), + (ByteNBT, 12, bytearray.fromhex("01 0C")), + (ShortNBT, 0, bytearray.fromhex("02 00 00")), + (ShortNBT, 1, bytearray.fromhex("02 00 01")), + (ShortNBT, 32767, bytearray.fromhex("02 7F FF")), + (ShortNBT, -32768, bytearray.fromhex("02 80 00")), + (ShortNBT, -1, bytearray.fromhex("02 FF FF")), + (ShortNBT, 12, bytearray.fromhex("02 00 0C")), + (IntNBT, 0, bytearray.fromhex("03 00 00 00 00")), + (IntNBT, 1, bytearray.fromhex("03 00 00 00 01")), + (IntNBT, 2147483647, bytearray.fromhex("03 7F FF FF FF")), + (IntNBT, -2147483648, bytearray.fromhex("03 80 00 00 00")), + (IntNBT, -1, bytearray.fromhex("03 FF FF FF FF")), + (IntNBT, 12, bytearray.fromhex("03 00 00 00 0C")), + (LongNBT, 0, bytearray.fromhex("04 00 00 00 00 00 00 00 00")), + (LongNBT, 1, bytearray.fromhex("04 00 00 00 00 00 00 00 01")), + (LongNBT, (1 << 63) - 1, bytearray.fromhex("04 7F FF FF FF FF FF FF FF")), + (LongNBT, -(1 << 63), bytearray.fromhex("04 80 00 00 00 00 00 00 00")), + (LongNBT, -1, bytearray.fromhex("04 FF FF FF FF FF FF FF FF")), + (LongNBT, 12, bytearray.fromhex("04 00 00 00 00 00 00 00 0C")), + (FloatNBT, 1.0, bytearray.fromhex("05") + bytes(struct.pack(">f", 1.0))), + (FloatNBT, 3.14, bytearray.fromhex("05") + bytes(struct.pack(">f", 3.14))), + (FloatNBT, -1.0, bytearray.fromhex("05") + bytes(struct.pack(">f", -1.0))), + (FloatNBT, 12.0, bytearray.fromhex("05") + bytes(struct.pack(">f", 12.0))), + (DoubleNBT, 1.0, bytearray.fromhex("06") + bytes(struct.pack(">d", 1.0))), + (DoubleNBT, 3.14, bytearray.fromhex("06") + bytes(struct.pack(">d", 3.14))), + (DoubleNBT, -1.0, bytearray.fromhex("06") + bytes(struct.pack(">d", -1.0))), + (DoubleNBT, 12.0, bytearray.fromhex("06") + bytes(struct.pack(">d", 12.0))), + (ByteArrayNBT, b"", bytearray.fromhex("07 00 00 00 00")), + (ByteArrayNBT, b"\x00", bytearray.fromhex("07 00 00 00 01") + b"\x00"), + (ByteArrayNBT, b"\x00\x01", bytearray.fromhex("07 00 00 00 02") + b"\x00\x01"), + (ByteArrayNBT, b"\x00\x01\x02", bytearray.fromhex("07 00 00 00 03") + b"\x00\x01\x02"), + (ByteArrayNBT, b"\x00\x01\x02\x03", bytearray.fromhex("07 00 00 00 04") + b"\x00\x01\x02\x03"), + (ByteArrayNBT, b"\xFF" * 1024, bytearray.fromhex("07 00 00 04 00") + b"\xFF" * 1024), + ( + ByteArrayNBT, + bytes((n - 1) * n * 2 % 256 for n in range(256)), + bytearray.fromhex("07 00 00 01 00") + bytes((n - 1) * n * 2 % 256 for n in range(256)), + ), + (StringNBT, "", bytearray.fromhex("08 00 00")), + (StringNBT, "test", bytearray.fromhex("08 00 04") + b"test"), + (StringNBT, "a" * 100, bytearray.fromhex("08 00 64") + b"a" * (100)), + (StringNBT, "&à@é", bytearray.fromhex("08 00 06") + bytes("&à@é", "utf-8")), + (ListNBT, [], bytearray.fromhex("09 00 00 00 00 00")), + (ListNBT, [ByteNBT(0)], bytearray.fromhex("09 01 00 00 00 01 00")), + (ListNBT, [ShortNBT(127), ShortNBT(256)], bytearray.fromhex("09 02 00 00 00 02 00 7F 01 00")), + ( + ListNBT, + [ListNBT([ByteNBT(0)]), ListNBT([IntNBT(256)])], + bytearray.fromhex("09 09 00 00 00 02 01 00 00 00 01 00 03 00 00 00 01 00 00 01 00"), + ), + (CompoundNBT, [], bytearray.fromhex("0A 00")), + ( + CompoundNBT, + [ByteNBT(0, name="test")], + bytearray.fromhex("0A") + ByteNBT(0, name="test").serialize() + b"\x00", + ), + ( + CompoundNBT, + [ShortNBT(128, "Short"), ByteNBT(-1, "Byte")], + bytearray.fromhex("0A") + ShortNBT(128, "Short").serialize() + ByteNBT(-1, "Byte").serialize() + b"\x00", + ), + ( + CompoundNBT, + [CompoundNBT([ByteNBT(0, name="Byte")], name="test")], + bytearray.fromhex("0A") + CompoundNBT([ByteNBT(0, name="Byte")], name="test").serialize() + b"\x00", + ), + ( + CompoundNBT, + [CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test"), IntNBT(-1, "Int 2")], + bytearray.fromhex("0A") + + CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test").serialize() + + IntNBT(-1, "Int 2").serialize() + + b"\x00", + ), + (IntArrayNBT, [], bytearray.fromhex("0B 00 00 00 00")), + (IntArrayNBT, [0], bytearray.fromhex("0B 00 00 00 01 00 00 00 00")), + (IntArrayNBT, [0, 1], bytearray.fromhex("0B 00 00 00 02 00 00 00 00 00 00 00 01")), + (IntArrayNBT, [1, 2, 3], bytearray.fromhex("0B 00 00 00 03 00 00 00 01 00 00 00 02 00 00 00 03")), + (IntArrayNBT, [(1 << 31) - 1], bytearray.fromhex("0B 00 00 00 01 7F FF FF FF")), + (IntArrayNBT, [(1 << 31) - 1, (1 << 31) - 2], bytearray.fromhex("0B 00 00 00 02 7F FF FF FF 7F FF FF FE")), + (IntArrayNBT, [-1, -2, -3], bytearray.fromhex("0B 00 00 00 03 FF FF FF FF FF FF FF FE FF FF FF FD")), + (IntArrayNBT, [12] * 1024, bytearray.fromhex("0B 00 00 04 00") + b"\x00\x00\x00\x0C" * 1024), + (LongArrayNBT, [], bytearray.fromhex("0C 00 00 00 00")), + (LongArrayNBT, [0], bytearray.fromhex("0C 00 00 00 01 00 00 00 00 00 00 00 00")), + (LongArrayNBT, [0, 1], bytearray.fromhex("0C 00 00 00 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 01")), + ( + LongArrayNBT, + [1, 2, 3], + bytearray.fromhex( + "0C 00 00 00 03 00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 03" + ), + ), + (LongArrayNBT, [(1 << 63) - 1], bytearray.fromhex("0C 00 00 00 01 7F FF FF FF FF FF FF FF")), + ( + LongArrayNBT, + [(1 << 63) - 1, (1 << 63) - 2], + bytearray.fromhex("0C 00 00 00 02 7F FF FF FF FF FF FF FF 7F FF FF FF FF FF FF FE"), + ), + ( + LongArrayNBT, + [-1, -2, -3], + bytearray.fromhex( + "0C 00 00 00 03 FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF FE FF FF FF FF FF FF FF FD" + ), + ), + (LongArrayNBT, [12] * 1024, bytearray.fromhex("0C 00 00 04 00") + b"\x00\x00\x00\x00\x00\x00\x00\x0C" * 1024), + ], +) +def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType, expected_bytes: bytes): + """Test serialization/deserialization of NBT tag without name.""" + # Test serialization + output_bytes = nbt_class(value).serialize(with_name=False) + output_bytes_no_type = nbt_class(value).serialize(with_type=False, with_name=False) + assert output_bytes == expected_bytes + assert output_bytes_no_type == expected_bytes[1:] + + buffer = Buffer() + nbt_class(value).write_to(buffer, with_name=False) + assert buffer == expected_bytes + + # Test deserialization + buffer = Buffer(expected_bytes) + assert NBTag.deserialize(buffer, with_name=False) == nbt_class(value) + + buffer = Buffer(expected_bytes[1:]) + assert nbt_class.deserialize(buffer, with_type=False, with_name=False) == nbt_class(value) + + buffer = Buffer(expected_bytes) + assert nbt_class.read_from(buffer, with_name=False) == nbt_class(value) + + buffer = Buffer(expected_bytes[1:]) + assert nbt_class.read_from(buffer, with_type=False, with_name=False) == nbt_class(value) + + +@pytest.mark.parametrize( + ("nbt_class", "value", "name", "expected_bytes"), + [ + (ByteNBT, 0, "test", bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("00")), + (ByteNBT, 1, "a", bytearray.fromhex("01") + b"\x00\x01a" + bytearray.fromhex("01")), + (ByteNBT, 127, "&à@é", bytearray.fromhex("01 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F")), + (ByteNBT, -128, "test", bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("80")), + (ByteNBT, 12, "a" * 100, bytearray.fromhex("01") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("0C")), + (ShortNBT, 0, "test", bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("00 00")), + (ShortNBT, 1, "a", bytearray.fromhex("02") + b"\x00\x01a" + bytearray.fromhex("00 01")), + (ShortNBT, 32767, "&à@é", bytearray.fromhex("02 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF")), + (ShortNBT, -32768, "test", bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("80 00")), + (ShortNBT, 12, "a" * 100, bytearray.fromhex("02") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 0C")), + (IntNBT, 0, "test", bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")), + (IntNBT, 1, "a", bytearray.fromhex("03") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01")), + ( + IntNBT, + 2147483647, + "&à@é", + bytearray.fromhex("03 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF FF FF"), + ), + (IntNBT, -2147483648, "test", bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00")), + ( + IntNBT, + 12, + "a" * 100, + bytearray.fromhex("03") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 0C"), + ), + (LongNBT, 0, "test", bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00 00 00 00")), + (LongNBT, 1, "a", bytearray.fromhex("04") + b"\x00\x01a" + bytearray.fromhex("00 00 00 00 00 00 00 01")), + ( + LongNBT, + (1 << 63) - 1, + "&à@é", + bytearray.fromhex("04 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF FF FF FF FF FF FF"), + ), + ( + LongNBT, + -1 << 63, + "test", + bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00 00 00 00 00"), + ), + ( + LongNBT, + 12, + "a" * 100, + bytearray.fromhex("04") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 00 00 00 00 0C"), + ), + (FloatNBT, 1.0, "test", bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 1.0))), + (FloatNBT, 3.14, "a", bytearray.fromhex("05") + b"\x00\x01a" + bytes(struct.pack(">f", 3.14))), + ( + FloatNBT, + -1.0, + "&à@é", + bytearray.fromhex("05 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">f", -1.0)), + ), + (FloatNBT, 12.0, "test", bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 12.0))), + (DoubleNBT, 1.0, "test", bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 1.0))), + (DoubleNBT, 3.14, "a", bytearray.fromhex("06") + b"\x00\x01a" + bytes(struct.pack(">d", 3.14))), + ( + DoubleNBT, + -1.0, + "&à@é", + bytearray.fromhex("06 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">d", -1.0)), + ), + (DoubleNBT, 12.0, "test", bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 12.0))), + (ByteArrayNBT, b"", "test", bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")), + ( + ByteArrayNBT, + b"\x00", + "a", + bytearray.fromhex("07") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01") + b"\x00", + ), + ( + ByteArrayNBT, + b"\x00\x01", + "&à@é", + bytearray.fromhex("07 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("00 00 00 02") + b"\x00\x01", + ), + ( + ByteArrayNBT, + b"\x00\x01\x02", + "test", + bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 03") + b"\x00\x01\x02", + ), + ( + ByteArrayNBT, + b"\xFF" * 1024, + "a" * 100, + bytearray.fromhex("07") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 04 00") + b"\xFF" * 1024, + ), + (StringNBT, "", "test", bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 00")), + (StringNBT, "test", "a", bytearray.fromhex("08") + b"\x00\x01a" + bytearray.fromhex("00 04") + b"test"), + ( + StringNBT, + "a" * 100, + "&à@é", + bytearray.fromhex("08 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("00 64") + b"a" * 100, + ), + ( + StringNBT, + "&à@é", + "test", + bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 06") + bytes("&à@é", "utf-8"), + ), + (ListNBT, [], "test", bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00")), + ( + ListNBT, + [ByteNBT(-1)], + "a", + bytearray.fromhex("09") + b"\x00\x01a" + bytearray.fromhex("01 00 00 00 01 FF"), + ), + ( + ListNBT, + [ShortNBT(127), ShortNBT(256)], + "test", + bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("02 00 00 00 02 00 7F 01 00"), + ), + ( + ListNBT, + [ListNBT([ByteNBT(-1)]), ListNBT([IntNBT(256)])], + "a", + bytearray.fromhex("09") + + b"\x00\x01a" + + bytearray.fromhex("09 00 00 00 02 01 00 00 00 01 FF 03 00 00 00 01 00 00 01 00"), + ), + (CompoundNBT, [], "test", bytearray.fromhex("0A") + b"\x00\x04test" + bytearray.fromhex("00")), + ( + CompoundNBT, + [ByteNBT(0, name="Byte")], + "test", + bytearray.fromhex("0A") + b"\x00\x04test" + ByteNBT(0, name="Byte").serialize() + b"\x00", + ), + ( + CompoundNBT, + [ShortNBT(128, "Short"), ByteNBT(-1, "Byte")], + "test", + bytearray.fromhex("0A") + + b"\x00\x04test" + + ShortNBT(128, "Short").serialize() + + ByteNBT(-1, "Byte").serialize() + + b"\x00", + ), + ( + CompoundNBT, + [CompoundNBT([ByteNBT(0, name="Byte")], name="test")], + "test", + bytearray.fromhex("0A") + + b"\x00\x04test" + + CompoundNBT([ByteNBT(0, name="Byte")], "test").serialize() + + b"\x00", + ), + ( + CompoundNBT, + [ListNBT([ByteNBT(0)], name="List")], + "test", + bytearray.fromhex("0A") + b"\x00\x04test" + ListNBT([ByteNBT(0)], name="List").serialize() + b"\x00", + ), + (IntArrayNBT, [], "test", bytearray.fromhex("0B") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")), + ( + IntArrayNBT, + [0], + "a", + bytearray.fromhex("0B") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01") + b"\x00\x00\x00\x00", + ), + ( + IntArrayNBT, + [0, 1], + "&à@é", + bytearray.fromhex("0B 00 06") + + bytes("&à@é", "utf-8") + + bytearray.fromhex("00 00 00 02") + + b"\x00\x00\x00\x00\x00\x00\x00\x01", + ), + ( + IntArrayNBT, + [1, 2, 3], + "test", + bytearray.fromhex("0B") + + b"\x00\x04test" + + bytearray.fromhex("00 00 00 03") + + b"\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03", + ), + ( + IntArrayNBT, + [(1 << 31) - 1], + "a" * 100, + bytearray.fromhex("0B") + + b"\x00\x64" + + b"a" * 100 + + bytearray.fromhex("00 00 00 01") + + b"\x7F\xFF\xFF\xFF", + ), + (LongArrayNBT, [], "test", bytearray.fromhex("0C") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")), + ( + LongArrayNBT, + [0], + "a", + bytearray.fromhex("0C") + + b"\x00\x01a" + + bytearray.fromhex("00 00 00 01") + + b"\x00\x00\x00\x00\x00\x00\x00\x00", + ), + ( + LongArrayNBT, + [0, 1], + "&à@é", + bytearray.fromhex("0C 00 06") + + bytes("&à@é", "utf-8") + + bytearray.fromhex("00 00 00 02") + + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + ), + ( + LongArrayNBT, + [1, 2, 3], + "test", + bytearray.fromhex("0C") + + b"\x00\x04test" + + bytearray.fromhex("00 00 00 03") + + bytearray.fromhex("00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 03"), + ), + ( + LongArrayNBT, + [(1 << 63) - 1] * 100, + "a" * 100, + bytearray.fromhex("0C") + + b"\x00\x64" + + b"a" * 100 + + bytearray.fromhex("00 00 00 64") + + b"\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF" * 100, + ), + ], +) +def test_serialize_deserialize(nbt_class: type[NBTag], value: PayloadType, name: str, expected_bytes: bytes): + """Test serialization/deserialization of NBT tag with name.""" + # Test serialization + output_bytes = nbt_class(value, name).serialize() + output_bytes_no_type = nbt_class(value, name).serialize(with_type=False) + assert output_bytes == expected_bytes + assert output_bytes_no_type == expected_bytes[1:] + + buffer = Buffer() + nbt_class(value, name).write_to(buffer) + assert buffer == expected_bytes + + # Test deserialization + buffer = Buffer(expected_bytes * 2) + assert buffer.remaining == len(expected_bytes) * 2 + assert NBTag.deserialize(buffer) == nbt_class(value, name=name) + assert buffer.remaining == len(expected_bytes) + assert NBTag.deserialize(buffer) == nbt_class(value, name=name) + assert buffer.remaining == 0 + + buffer = Buffer(expected_bytes[1:]) + assert nbt_class.deserialize(buffer, with_type=False) == nbt_class(value, name=name) + + buffer = Buffer(expected_bytes) + assert nbt_class.read_from(buffer) == nbt_class(value, name=name) + + buffer = Buffer(expected_bytes[1:]) + assert nbt_class.read_from(buffer, with_type=False) == nbt_class(value, name=name) + + +@pytest.mark.parametrize( + ("nbt_class", "size", "tag"), + [ + (ByteNBT, 8, NBTagType.BYTE), + (ShortNBT, 16, NBTagType.SHORT), + (IntNBT, 32, NBTagType.INT), + (LongNBT, 64, NBTagType.LONG), + ], +) +def test_serialize_deserialize_numerical_fail(nbt_class: type[NBTag], size: int, tag: NBTagType): + """Test serialization/deserialization of NBT NUM tag with invalid value.""" + # Out of bounds + with pytest.raises(OverflowError): + nbt_class(1 << (size - 1)).serialize(with_name=False) + + with pytest.raises(OverflowError): + nbt_class(-(1 << (size - 1)) - 1).serialize(with_name=False) + + with pytest.raises(ValueError): # No name + nbt_class(0, "").serialize() # without with_name=False + + # Deserialization + buffer = Buffer(bytearray([tag.value + 1] + [0] * (size // 8))) + with pytest.raises(TypeError): # Tries to read a nbt_class, but it's one higher + nbt_class.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([tag.value] + [0] * ((size // 8) - 1))) + with pytest.raises(IOError): + nbt_class.read_from(buffer, with_name=False) + + buffer = Buffer(bytearray([tag.value, 0, 0] + [0] * (size // 8))) + assert nbt_class.read_from(buffer, with_name=True) == nbt_class(0) + + +# endregion + +# region FloatNBT + + +def test_serialize_deserialize_float_fail(): + """Test serialization/deserialization of NBT FLOAT tag with invalid value.""" + with pytest.raises(ValueError): + FloatNBT(0, 0).serialize() # type:ignore + + with pytest.raises(struct.error): + FloatNBT("test").serialize(with_name=False) + + with pytest.raises(OverflowError): + FloatNBT(1e39, "test").serialize() + + with pytest.raises(OverflowError): + FloatNBT(-1e39, "test").serialize() + + # Deserialization + buffer = Buffer(bytearray([NBTagType.BYTE] + [0] * 4)) + with pytest.raises(TypeError): # Tries to read a FloatNBT, but it's a ByteNBT + FloatNBT.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.FLOAT, 0, 0, 0])) + with pytest.raises(IOError): + FloatNBT.read_from(buffer, with_name=False) + + +# endregion +# region DoubleNBT + + +def test_serialize_deserialize_double_fail(): + """Test serialization/deserialization of NBT DOUBLE tag with invalid value.""" + with pytest.raises(ValueError): + DoubleNBT(0, 0).serialize() # type: ignore + + with pytest.raises(struct.error): + DoubleNBT("test").serialize(with_name=False) + + # Deserialization + buffer = Buffer(bytearray([0x01] + [0] * 8)) + with pytest.raises(TypeError): # Tries to read a DoubleNBT, but it's a ByteNBT + DoubleNBT.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.DOUBLE, 0, 0, 0, 0, 0, 0, 0])) + with pytest.raises(IOError): + DoubleNBT.read_from(buffer, with_name=False) + + +# endregion +# region ByteArrayNBT + + +def test_serialize_deserialize_bytearray_fail(): + """Test serialization/deserialization of NBT BYTEARRAY tag with invalid value.""" + with pytest.raises(ValueError): + ByteArrayNBT([], 0).serialize() # type:ignore + + with pytest.raises(ValueError): + ByteArrayNBT(b"test", "").serialize() + + # Deserialization + buffer = Buffer(bytearray([0x01] + [0] * 4)) + with pytest.raises(TypeError): # Tries to read a ByteArrayNBT, but it's a ByteNBT + ByteArrayNBT.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0])) # Missing length bytes + with pytest.raises(IOError): + ByteArrayNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0, 1])) # Missing data bytes + with pytest.raises(IOError): + ByteArrayNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0, 2, 0])) # Missing data bytes + with pytest.raises(IOError): + ByteArrayNBT.read_from(buffer, with_name=False) + + # Negative length + buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0xFF, 0xFF, 0xFF, 0xFF])) # length = -1 + with pytest.raises(ValueError): + ByteArrayNBT.deserialize(buffer, with_name=False) + + +# endregion +# region StringNBT + + +def test_serialize_deserialize_string_fail(): + """Test serialization/deserialization of NBT STRING tag with invalid value.""" + with pytest.raises(ValueError): + StringNBT("", 0).serialize() # type:ignore + + with pytest.raises(ValueError): + StringNBT("test", "").serialize() + + # Deserialization + buffer = Buffer(bytearray([0x01, 0, 0])) + with pytest.raises(TypeError): # Tries to read a StringNBT, but it's a ByteNBT + StringNBT.deserialize(buffer, with_name=False) + + # Not enough data for the length + buffer = Buffer(bytearray([NBTagType.STRING, 0])) + with pytest.raises(IOError): + StringNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.STRING, 0, 1])) + with pytest.raises(IOError): + StringNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.STRING, 0, 2, 0])) + with pytest.raises(IOError): + StringNBT.read_from(buffer, with_name=False) + + # Negative length + buffer = Buffer(bytearray([NBTagType.STRING, 0xFF, 0xFF])) # length = -1 + with pytest.raises(ValueError): + StringNBT.deserialize(buffer, with_name=False) + + # Invalid UTF-8 + buffer = Buffer(bytearray([NBTagType.STRING, 0, 1, 0xC0, 0x80])) + with pytest.raises(UnicodeDecodeError): + StringNBT.read_from(buffer, with_name=False) + + +# endregion +# region ListNBT + + +@pytest.mark.parametrize( + ("payload", "error"), + [ + ([ByteNBT(0), IntNBT(0)], ValueError), + ([ByteNBT(0), "test"], ValueError), + ([ByteNBT(0), None], ValueError), + ([ByteNBT(0), ByteNBT(-1, "Hello World")], ValueError), # All unnamed tags + ([ByteNBT(128), ByteNBT(-1)], OverflowError), # Check for error propagation + ], +) +def test_serialize_list_fail(payload, error): + """Test serialization of NBT LIST tag with invalid value.""" + with pytest.raises(error): + ListNBT(payload, "test").serialize() + + +def test_deserialize_list_fail(): + """Test deserialization of NBT LIST tag with invalid value.""" + # Wrong tag type + buffer = Buffer(bytearray([0x09, 255, 0, 0, 0, 1, 0])) + with pytest.raises(TypeError): + ListNBT.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([0x09, 1, 0, 0, 0, 1])) + with pytest.raises(IOError): + ListNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([0x09, 1, 0, 0, 0])) + with pytest.raises(IOError): + ListNBT.read_from(buffer, with_name=False) + + +# endregion +# region CompoundNBT + + +@pytest.mark.parametrize( + ("payload", "error"), + [ + ([ByteNBT(0, name="Hello"), IntNBT(0)], ValueError), + ([ByteNBT(0, name="hi"), "test"], ValueError), + ([ByteNBT(0, name="hi"), None], ValueError), + ([ByteNBT(0), ByteNBT(-1, "Hello World")], ValueError), # All unnamed tags + ([ByteNBT(128, name="Jello"), ByteNBT(-1, name="Bonjour")], OverflowError), # Check for error propagation + ], +) +def test_serialize_compound_fail(payload, error): + """Test serialization of NBT COMPOUND tag with invalid value.""" + with pytest.raises(error): + CompoundNBT(payload, "test").serialize() + + # Double name + with pytest.raises(ValueError): + CompoundNBT([ByteNBT(0, name="test"), ByteNBT(0, name="test")], "comp").serialize() + + +def test_deseialize_compound_fail(): + """Test deserialization of NBT COMPOUND tag with invalid value.""" + # Not enough data + buffer = Buffer(bytearray([NBTagType.COMPOUND, 0x01])) + with pytest.raises(IOError): + CompoundNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.COMPOUND])) + with pytest.raises(IOError): + CompoundNBT.read_from(buffer, with_name=False) + + # Wrong tag type + buffer = Buffer(bytearray([15])) + with pytest.raises(TypeError): + NBTag.deserialize(buffer) + + +def test_to_object_compound(): + """Try a few incorrect CompoundNBT.to_object() calls.""" + comp = CompoundNBT([ByteNBT(0, "test"), ByteNBT(1, "test")]) + with pytest.raises(ValueError): + comp.to_object() # Duplicate name + + comp = CompoundNBT([ByteNBT(0), ByteNBT(1)]) + with pytest.raises(ValueError): + comp.to_object() + + +def test_equality_compound(): + """Test equality of CompoundNBT.""" + comp1 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp") + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp") + assert comp1 == comp2 + + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2")], "comp") + assert comp1 != comp2 + + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test4")], "comp") + assert comp1 != comp2 + + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp2") + assert comp1 != comp2 + + assert comp1 != ByteNBT(0, name="comp") + + +# endregion +# region IntArrayNBT + + +@pytest.mark.parametrize( + ("payload", "error"), + [ + ([0, "test"], ValueError), + ([0, None], ValueError), + ([0, 1 << 31], OverflowError), + ([0, -(1 << 31) - 1], OverflowError), + ], +) +def test_serialize_intarray_fail(payload, error): + """Test serialization of NBT INTARRAY tag with invalid value.""" + with pytest.raises(error): + IntArrayNBT(payload, "test").serialize() + + +def test_deserialize_intarray_fail(): + """Test deserialization of NBT INTARRAY tag with invalid value.""" + # Not enough data for 1 element + buffer = Buffer(bytearray([0x0B, 0, 0, 0, 1, 0, 0, 0])) + with pytest.raises(IOError): + IntArrayNBT.deserialize(buffer, with_name=False) + + # Not enough data for the size + buffer = Buffer(bytearray([0x0B, 0, 0, 0])) + with pytest.raises(IOError): + IntArrayNBT.read_from(buffer, with_name=False) + + # Not enough data to start the 2nd element + buffer = Buffer(bytearray([0x0B, 0, 0, 0, 2, 1, 0, 0, 0])) + with pytest.raises(IOError): + IntArrayNBT.read_from(buffer, with_name=False) + + +# endregion +# region LongArrayNBT + + +@pytest.mark.parametrize( + ("payload", "error"), + [ + ([0, "test"], ValueError), + ([0, None], ValueError), + ([0, 1 << 63], OverflowError), + ([0, -(1 << 63) - 1], OverflowError), + ], +) +def test_serialize_deserialize_longarray_fail(payload, error): + """Test serialization/deserialization of NBT LONGARRAY tag with invalid value.""" + with pytest.raises(error): + LongArrayNBT(payload, "test").serialize() + + +def test_deserialize_longarray_fail(): + """Test deserialization of NBT LONGARRAY tag with invalid value.""" + # Not enough data for 1 element + buffer = Buffer(bytearray([0x0C, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0])) + with pytest.raises(IOError): + LongArrayNBT.deserialize(buffer, with_name=False) + + # Not enough data for the size + buffer = Buffer(bytearray([0x0C, 0, 0, 0])) + with pytest.raises(IOError): + LongArrayNBT.read_from(buffer, with_name=False) + + # Not enough data to start the 2nd element + buffer = Buffer(bytearray([0x0C, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0])) + with pytest.raises(IOError): + LongArrayNBT.read_from(buffer, with_name=False) + + +# endregion + +# region NBTag + + +def test_nbt_helloworld(): + """Test serialization/deserialization of a simple NBT tag. + + Source data: https://wiki.vg/NBT#Example. + """ + data = bytearray.fromhex("0a000b68656c6c6f20776f726c640800046e616d65000942616e616e72616d6100") + buffer = Buffer(data) + + expected_object = { + "hello world": { + "name": "Bananrama", + } + } + + data = CompoundNBT.deserialize(buffer) + assert data == NBTag.from_object(expected_object) + assert data.to_object() == expected_object + + +def test_nbt_bigfile(): + """Test serialization/deserialization of a big NBT tag. + + Slighly modified from the source data to also include a IntArrayNBT and a LongArrayNBT. + Source data: https://wiki.vg/NBT#Example. + """ + data = "0a00054c6576656c0400086c6f6e67546573747fffffffffffffff02000973686f7274546573747fff08000a737472696e6754657374002948454c4c4f20574f524c4420544849532049532041205445535420535452494e4720c385c384c39621050009666c6f6174546573743eff1832030007696e74546573747fffffff0a00146e657374656420636f6d706f756e6420746573740a000368616d0800046e616d65000648616d70757305000576616c75653f400000000a00036567670800046e616d6500074567676265727405000576616c75653f00000000000c000f6c6973745465737420286c6f6e672900000005000000000000000b000000000000000c000000000000000d000000000000000e7fffffffffffffff0b000e6c697374546573742028696e7429000000047fffffff7ffffffe7ffffffd7ffffffc0900136c697374546573742028636f6d706f756e64290a000000020800046e616d65000f436f6d706f756e642074616720233004000a637265617465642d6f6e000001265237d58d000800046e616d65000f436f6d706f756e642074616720233104000a637265617465642d6f6e000001265237d58d0001000862797465546573747f07006562797465417272617954657374202874686520666972737420313030302076616c756573206f6620286e2a6e2a3235352b6e2a3729253130302c207374617274696e672077697468206e3d302028302c2036322c2033342c2031362c20382c202e2e2e2929000003e8003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a063005000a646f75626c65546573743efc7b5e00" # noqa: E501 + data = bytearray.fromhex(data) + buffer = Buffer(data) + + expected_object = { + "Level": { + "longTest": 9223372036854775807, + "shortTest": 32767, + "stringTest": "HELLO WORLD THIS IS A TEST STRING ÅÄÖ!", + "floatTest": 0.4982314705848694, + "intTest": 2147483647, + "nested compound test": { + "ham": {"name": "Hampus", "value": 0.75}, + "egg": {"name": "Eggbert", "value": 0.5}, + }, + "listTest (long)": [11, 12, 13, 14, 9223372036854775807], + "listTest (int)": [2147483647, 2147483646, 2147483645, 2147483644], + "listTest (compound)": [ + {"name": "Compound tag #0", "created-on": 1264099775885}, + {"name": "Compound tag #1", "created-on": 1264099775885}, + ], + "byteTest": 127, + "byteArrayTest (the first 1000 values of (n*n*255+n*7)%100" + ", starting with n=0 (0, 62, 34, 16, 8, ...))": bytearray( + (n * n * 255 + n * 7) % 100 for n in range(1000) + ), + "doubleTest": 0.4931287132182315, + } + } + + data = CompoundNBT.deserialize(buffer) + # print(f"{data=}\n{expected_object=}\n{data.to_object()=}\n{NBTag.from_object(expected_object)=}") + + def check_equality(self, other): + """Check if two objects are equal, with deep epsilon check for floats.""" + if type(self) != type(other): + return False + if isinstance(self, dict): + if len(self) != len(other): + return False + for key in self: + if key not in other: + return False + if not check_equality(self[key], other[key]): + return False + return True + if isinstance(self, list): + if len(self) != len(other): + return False + return all(check_equality(self[i], other[i]) for i in range(len(self))) + if isinstance(self, float): + return abs(self - other) < 1e-6 + if self != other: + return False + return self == other + + assert data == NBTag.from_object(expected_object) + assert check_equality(data.to_object(), expected_object) + + +# endregion +# region Edge cases + + +def test_from_object_lst_not_same_type(): + """Test from_object with a list that does not have the same type.""" + with pytest.raises(TypeError): + NBTag.from_object([ByteNBT(0), IntNBT(0)]) + + +def test_from_object_out_of_bounds(): + """Test from_object with a value that is out of bounds.""" + with pytest.raises(ValueError): + NBTag.from_object({"test": 1 << 63}) + + with pytest.raises(ValueError): + NBTag.from_object({"test": -(1 << 63) - 1}) + + with pytest.raises(ValueError): + NBTag.from_object({"test": [1 << 63]}) + + with pytest.raises(ValueError): + NBTag.from_object({"test": [-(1 << 63) - 1]}) + + +def test_from_object_morecases(): + """Test from_object with more edge cases.""" + + class CustomType: + def __bytes__(self): + return b"test" + + assert NBTag.from_object( + { + "nbtag": ByteNBT(0), # ByteNBT + "bytearray": b"test", # Conversion from bytes + "empty_list": [], # Empty list with type EndNBT + "empty_compound": {}, # Empty compound + "end_NBTag": None, # Should not be done in practice, would create a broken buffer if serialized + "custom": CustomType(), # Custom type with __bytes__ method + } + ) == CompoundNBT( + [ # Order is shuffled because the spec does not require a specific order + CompoundNBT([], "empty_compound"), + ByteArrayNBT(b"test", "bytearray"), + ByteArrayNBT(b"test", "custom"), + ListNBT([], "empty_list"), + ByteNBT(0, "nbtag"), + EndNBT(), + ] + ) + + # Not a valid object + with pytest.raises(TypeError): + NBTag.from_object({"test": object()}) + + compound = CompoundNBT.from_object( + { + "test": ByteNBT(0), + "test2": IntNBT(0), + }, + name="compound", + ) + assert compound["test"] == ByteNBT(0, "test") + assert compound["test2"] == IntNBT(0, "test2") + with pytest.raises(KeyError): + compound["test3"] + + # Cannot index into a ByteNBT + with pytest.raises(TypeError): + compound["test"][0] # type:ignore + + listnbt = ListNBT.from_object([0, 1, 2], use_int_array=False) + assert listnbt[0] == ByteNBT(0) + assert listnbt[1] == ByteNBT(1) + assert listnbt[2] == ByteNBT(2) + with pytest.raises(IndexError): + listnbt[3] + with pytest.raises(TypeError): + listnbt["hello"] + + assert listnbt[-1] == ByteNBT(2) + assert listnbt[-2] == ByteNBT(1) + assert listnbt[-3] == ByteNBT(0) + + with pytest.raises(TypeError): + listnbt[object()] # type:ignore + + assert listnbt.value == [0, 1, 2] + assert listnbt.to_object() == [0, 1, 2] + assert ListNBT([]).value == [] + assert compound.to_object() == {"compound": {"test": 0, "test2": 0}} + assert compound.value == {"test": 0, "test2": 0} + assert ListNBT([IntNBT(0)]).value == [0] + + assert ByteNBT(12).value == 12 + assert ShortNBT(13).value == 13 + assert IntNBT(14).value == 14 + assert LongNBT(15).value == 15 + assert FloatNBT(0.5).value == 0.5 + assert DoubleNBT(0.6).value == 0.6 + assert ByteArrayNBT(b"test").value == b"test" + assert StringNBT("test").value == "test" + assert IntArrayNBT([0, 1, 2]).value == [0, 1, 2] + assert LongArrayNBT([0, 1, 2, 3]).value == [0, 1, 2, 3] + + invalid = ListNBT("Hello", "name") + with pytest.raises(AttributeError): + invalid[0] + + invalid = CompoundNBT([ByteNBT(0, "Byte"), "Hi"], "name") + with pytest.raises(AttributeError): + invalid["Byte"] # Attribute error is raised when the structure is incorrectly constructed + + +def test_to_object_morecases(): + """Test to_object with more edge cases.""" + + class CustomType: + def __bytes__(self): + return b"test" + + assert NBTag.from_object( + { + "bytearray": b"test", + "empty_list": [], + "empty_compound": {}, + "custom": CustomType(), + } + ).to_object() == { + "bytearray": b"test", + "empty_list": [], + "empty_compound": {}, + "custom": b"test", + } + + assert NBTag.to_object(CompoundNBT([])) == {} + + assert EndNBT().to_object() == {} # Does not add anything when doing dict.update + assert FloatNBT(0.5).to_object() == 0.5 + assert FloatNBT(0.5, "Hello World").to_object() == {"Hello World": 0.5} + assert ByteArrayNBT(b"test").to_object() == b"test" # Do not add name when there is no name + assert StringNBT("test").to_object() == "test" + assert StringNBT("test", "name").to_object() == {"name": "test"} + assert ListNBT([ByteNBT(0), ByteNBT(1)]).to_object() == [0, 1] + assert ListNBT([ByteNBT(0), ByteNBT(1)], "name").to_object() == {"name": [0, 1]} + assert IntArrayNBT([0, 1, 2]).to_object() == [0, 1, 2] + assert LongArrayNBT([0, 1, 2]).to_object() == [0, 1, 2] + + +def test_data_conversions(): + """Test data conversions using the built-in functions.""" + assert int(IntNBT(-1)) == -1 + assert float(FloatNBT(0.5)) == 0.5 + assert str(StringNBT("test")) == "test" + assert bytes(ByteArrayNBT(b"test")) == b"test" + assert list(ListNBT([ByteNBT(0), ByteNBT(1)])) == [ByteNBT(0), ByteNBT(1)] + assert dict(CompoundNBT([ByteNBT(0, "first"), ByteNBT(1, "second")])) == { + "first": ByteNBT(0, "first"), + "second": ByteNBT(1, "second"), + } + assert list(IntArrayNBT([0, 1, 2])) == [0, 1, 2] + assert list(LongArrayNBT([0, 1, 2])) == [0, 1, 2] + + +def test_init_nbtag_directly(): + """Test initializing NBTag directly.""" + with pytest.raises(TypeError): + NBTag(0) + with pytest.raises(TypeError): + NBTag(0, "test") + with pytest.raises(TypeError): + NBTag(0, name="test") + + +# endregion From 6e9d3125bf560fe7ae8fde7dd9d0dee2f2f57c10 Mon Sep 17 00:00:00 2001 From: Alexis Rossfelder Date: Tue, 30 Apr 2024 12:48:23 +0200 Subject: [PATCH 2/3] Simplify NBT implementation * Remove the TYPE class variable to rely in the actual type of each tag * Change the from_object and to_object to use schemas describing the data instead of using integer ranges and choosing types arbitrarily --- changes/257.feature.md | 7 +- mcproto/types/nbt.py | 890 +++++++++++++++++--------------- tests/mcproto/types/test_nbt.py | 705 ++++++++++++++++++------- 3 files changed, 999 insertions(+), 603 deletions(-) diff --git a/changes/257.feature.md b/changes/257.feature.md index 32fefeab..453b22bf 100644 --- a/changes/257.feature.md +++ b/changes/257.feature.md @@ -1,9 +1,12 @@ - Added the `NBTag` to deal with NBT data: - The `NBTag` class is the base class for all NBT tags and provides the basic functionality to serialize and deserialize NBT data from and to a `Buffer` object. - The classes `EndNBT`, `ByteNBT`, `ShortNBT`, `IntNBT`, `LongNBT`, `FloatNBT`, `DoubleNBT`, `ByteArrayNBT`, `StringNBT`, `ListNBT`, `CompoundNBT`, `IntArrayNBT`and `LongArrayNBT` were added and correspond to the NBT types described in the [NBT specification](https://wiki.vg/NBT#Specification). - - NBT tags can be created using the `NBTag.from_object()` method, which automatically selects the correct tag type based on the object's type and works recursively for lists and dictionaries. - - The `NBTag.to_object()` method can be used to convert an NBT tag back to a Python object. + - NBT tags can be created using the `NBTag.from_object()` method and a schema that describes the NBT tag structure. + Compound tags are represented as dictionaries, list tags as lists, and primitive tags as their respective Python types. + The implementation allows to add custom classes to the schema to handle custom NBT tags if they inherit the `:class: NBTagConvertible` class. + - The `NBTag.to_object()` method can be used to convert an NBT tag back to a Python object. Use include_schema=True to include the schema in the output, and `include_name=True` to include the name of the tag in the output. In that case the output will be a dictionary with a single key that is the name of the tag and the value is the object representation of the tag. - The `NBTag.serialize()` can be used to serialize an NBT tag to a new `Buffer` object. - The `NBTag.deserialize(buffer)` can be used to deserialize an NBT tag from a `Buffer` object. - If the buffer already exists, the `NBTag.write_to(buffer, with_type=True, with_name=True)` method can be used to write the NBT tag to the buffer (and in that case with the type and name in the right format). - The `NBTag.read_from(buffer, with_type=True, with_name=True)` method can be used to read an NBT tag from the buffer (and in that case with the type and name in the right format). + - The `NBTag.value` property can be used to get the value of the NBT tag as a Python object. diff --git a/mcproto/types/nbt.py b/mcproto/types/nbt.py index e8a057bb..8932b5cf 100644 --- a/mcproto/types/nbt.py +++ b/mcproto/types/nbt.py @@ -1,17 +1,18 @@ from __future__ import annotations -import warnings -from abc import ABCMeta +from abc import abstractmethod from enum import IntEnum -from typing import ClassVar, List, Mapping, Union, cast +from typing import Iterator, List, Mapping, Sequence, Tuple, Type, Union, cast -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, override +from typing import Protocol, runtime_checkable # Have to be imported from the same place from mcproto.buffer import Buffer from mcproto.protocol.base_io import StructFormat from mcproto.types.abc import MCType __all__ = [ + "NBTagConvertible", "NBTagType", "NBTag", "EndNBT", @@ -29,103 +30,105 @@ "LongArrayNBT", ] - -# region NBT Specification """ -Source : https://web.archive.org/web/20110723210920/http://www.minecraft.net/docs/NBT.txt +Implementation of the NBT (Named Binary Tag) format used in Minecraft as described in the NBT specification +(:seealso: :class:`NBTagType`). +""" +# region NBT Specification -Named Binary Tag specification -NBT (Named Binary Tag) is a tag based binary format designed to carry large amounts of binary data with smaller amounts -of additional data. -An NBT file consists of a single GZIPped Named Tag of type TAG_Compound. +class NBTagType(IntEnum): + """Enumeration of the different types of NBT tags. -A Named Tag has the following format: + Source : https://web.archive.org/web/20110723210920/http://www.minecraft.net/docs/NBT.txt - byte tagType - TAG_String name - [payload] + Named Binary Tag specification -The tagType is a single byte defining the contents of the payload of the tag. + NBT (Named Binary Tag) is a tag based binary format designed to carry large amounts of binary data with smaller + amounts of additional data. + An NBT file consists of a single GZIPped Named Tag of type TAG_Compound. -The name is a descriptive name, and can be anything (eg "cat", "banana", "Hello World!"). It has nothing to do with the -tagType. -The purpose for this name is to name tags so parsing is easier and can be made to only look for certain recognized tag -names. -Exception: If tagType is TAG_End, the name is skipped and assumed to be "". + A Named Tag has the following format: -The [payload] varies by tagType. + byte tagType + TAG_String name + [payload] -Note that ONLY Named Tags carry the name and tagType data. Explicitly identified Tags (such as TAG_String above) only -contains the payload. + The tagType is a single byte defining the contents of the payload of the tag. + The name is a descriptive name, and can be anything (eg "cat", "banana", "Hello World!"). It has nothing to do with + the tagType. + The purpose for this name is to name tags so parsing is easier and can be made to only look for certain recognized + tag names. + Exception: If tagType is TAG_End, the name is skipped and assumed to be "". -The tag types and respective payloads are: + The [payload] varies by tagType. - TYPE: 0 NAME: TAG_End - Payload: None. - Note: This tag is used to mark the end of a list. - Cannot be named! If type 0 appears where a Named Tag is expected, the name is assumed to be "". - (In other words, this Tag is always just a single 0 byte when named, and nothing in all other cases) + Note that ONLY Named Tags carry the name and tagType data. Explicitly identified Tags (such as TAG_String above) + only contains the payload. - TYPE: 1 NAME: TAG_Byte - Payload: A single signed byte (8 bits) - TYPE: 2 NAME: TAG_Short - Payload: A signed short (16 bits, big endian) + The tag types and respective payloads are: - TYPE: 3 NAME: TAG_Int - Payload: A signed short (32 bits, big endian) + TYPE: 0 NAME: TAG_End + Payload: None. + Note: This tag is used to mark the end of a list. + Cannot be named! If type 0 appears where a Named Tag is expected, the name is assumed to be "". + (In other words, this Tag is always just a single 0 byte when named, and nothing in all other cases) - TYPE: 4 NAME: TAG_Long - Payload: A signed long (64 bits, big endian) + TYPE: 1 NAME: TAG_Byte + Payload: A single signed byte (8 bits) - TYPE: 5 NAME: TAG_Float - Payload: A floating point value (32 bits, big endian, IEEE 754-2008, binary32) + TYPE: 2 NAME: TAG_Short + Payload: A signed short (16 bits, big endian) - TYPE: 6 NAME: TAG_Double - Payload: A floating point value (64 bits, big endian, IEEE 754-2008, binary64) + TYPE: 3 NAME: TAG_Int + Payload: A signed short (32 bits, big endian) - TYPE: 7 NAME: TAG_Byte_Array - Payload: TAG_Int length - An array of bytes of unspecified format. The length of this array is bytes + TYPE: 4 NAME: TAG_Long + Payload: A signed long (64 bits, big endian) - TYPE: 8 NAME: TAG_String - Payload: TAG_Short length - An array of bytes defining a string in UTF-8 format. The length of this array is bytes + TYPE: 5 NAME: TAG_Float + Payload: A floating point value (32 bits, big endian, IEEE 754-2008, binary32) - TYPE: 9 NAME: TAG_List - Payload: TAG_Byte tagId - TAG_Int length - A sequential list of Tags (not Named Tags), of type . The length of this array is Tags - Notes: All tags share the same type. + TYPE: 6 NAME: TAG_Double + Payload: A floating point value (64 bits, big endian, IEEE 754-2008, binary64) - TYPE: 10 NAME: TAG_Compound - Payload: A sequential list of Named Tags. This array keeps going until a TAG_End is found. - TAG_End end - Notes: If there's a nested TAG_Compound within this tag, that one will also have a TAG_End, so simply reading - until the next TAG_End will not work. - The names of the named tags have to be unique within each TAG_Compound - The order of the tags is not guaranteed. + TYPE: 7 NAME: TAG_Byte_Array + Payload: TAG_Int length + An array of bytes of unspecified format. The length of this array is bytes + TYPE: 8 NAME: TAG_String + Payload: TAG_Short length + An array of bytes defining a string in UTF-8 format. The length of this array is bytes - // NEW TAGS - TYPE: 11 NAME: TAG_Int_Array - Payload: TAG_Int length - An array of integers. The length of this array is integers + TYPE: 9 NAME: TAG_List + Payload: TAG_Byte tagId + TAG_Int length + A sequential list of Tags (not Named Tags), of type . The length of this array is + Tags + Notes: All tags share the same type. - TYPE: 12 NAME: TAG_Long_Array - Payload: TAG_Int length - An array of longs. The length of this array is longs + TYPE: 10 NAME: TAG_Compound + Payload: A sequential list of Named Tags. This array keeps going until a TAG_End is found. + TAG_End end + Notes: If there's a nested TAG_Compound within this tag, that one will also have a TAG_End, so simply reading + until the next TAG_End will not work. + The names of the named tags have to be unique within each TAG_Compound + The order of the tags is not guaranteed. -""" -# endregion -# region NBT base classes/types + // NEW TAGS + TYPE: 11 NAME: TAG_Int_Array + Payload: TAG_Int length + An array of integers. The length of this array is integers + TYPE: 12 NAME: TAG_Long_Array + Payload: TAG_Int length + An array of longs. The length of this array is longs -class NBTagType(IntEnum): - """Types of NBT tags.""" + + """ END = 0 BYTE = 1 @@ -145,44 +148,62 @@ class NBTagType(IntEnum): PayloadType: TypeAlias = Union[ int, float, - bytearray, bytes, str, - List["PayloadType"], - Mapping[str, "PayloadType"], - List[int], "NBTag", - List["NBTag"], + Sequence["PayloadType"], + Mapping[str, "PayloadType"], ] -class _MetaNBTag(ABCMeta): - """Metaclass for NBT tags.""" +@runtime_checkable +class NBTagConvertible(Protocol): + """Protocol for objects that can be converted to an NBT tag.""" + + __slots__ = () + + def to_nbt(self, name: str = "") -> NBTag: + """Convert the object to an NBT tag. + + :param name: The name of the tag. + + :return: The NBT tag created from the object. + """ + ... - TYPE: NBTagType = NBTagType.COMPOUND - def __new__(cls, name: str, bases: tuple[type], namespace: dict, **kwargs): - new_cls: NBTag = super().__new__(cls, name, bases, namespace) # type: ignore - if name != "NBTag": - NBTag.ASSOCIATED_TYPES[new_cls.TYPE] = new_cls # type: ignore - return new_cls +FromObjectType: TypeAlias = Union[ + int, + float, + bytes, + str, + NBTagConvertible, + Sequence["FromObjectType"], + Mapping[str, "FromObjectType"], +] +FromObjectSchema: TypeAlias = Union[ + Type["NBTag"], + Type[NBTagConvertible], + Sequence["FromObjectSchema"], + Mapping[str, "FromObjectSchema"], +] -class NBTag(MCType, metaclass=_MetaNBTag): - """Base class for NBT tags.""" - __slots__ = ("name", "payload") +class NBTag(MCType, NBTagConvertible): + """Base class for NBT tags. - TYPE: ClassVar[NBTagType] = NBTagType.COMPOUND + In MC v1.20.2+ the type and name of the root tag are not written to the buffer, and unless specified, the type of + the tag is assumed to be TAG_Compound. + """ - ASSOCIATED_TYPES: ClassVar[dict[NBTagType, type[NBTag]]] = {} + __slots__ = ("name", "payload") def __init__(self, payload: PayloadType, name: str = ""): - if self.__class__ == NBTag: - raise TypeError("Cannot instantiate an NBTag object directly, use a subclass instead.") self.name = name self.payload = payload + @override def serialize(self, with_type: bool = True, with_name: bool = True) -> Buffer: """Serialize the NBT tag to a buffer. @@ -196,21 +217,19 @@ def serialize(self, with_type: bool = True, with_name: bool = True) -> Buffer: self.write_to(buf, with_name=with_name, with_type=with_type) return buf - def _write_header(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> bool: + def _write_header(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: if with_type: - buf.write_value(StructFormat.BYTE, self.TYPE.value) - if self.TYPE == NBTagType.END: - return False - if with_name: - if not self.name: - raise ValueError("Named tags must have a name.") + tag_type = _get_tag_type(self) + buf.write_value(StructFormat.BYTE, tag_type.value) + if with_name and self.name: StringNBT(self.name).write_to(buf, with_type=False, with_name=False) - return True + @abstractmethod def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the NBT tag to the buffer.""" ... + @override @classmethod def deserialize(cls, buf: Buffer, with_name: bool = True, with_type: bool = True) -> NBTag: """Deserialize the NBT tag. @@ -229,7 +248,7 @@ def deserialize(cls, buf: Buffer, with_name: bool = True, with_type: bool = True """ name, tag_type = cls._read_header(buf, with_name=with_name, read_type=with_type) - tag_class = NBTag.ASSOCIATED_TYPES[tag_type] + tag_class = ASSOCIATED_TYPES[tag_type] if cls not in (NBTag, tag_class): raise TypeError(f"Expected a {cls.__name__} tag, but found a different tag ({tag_class.__name__}).") @@ -251,16 +270,17 @@ def _read_header(cls, buf: Buffer, read_type: bool = True, with_name: bool = Tru :note: It is possible that this function reads nothing from the buffer if both with_name and read_type are set to False. """ - tag_type: NBTagType = cls.TYPE # default value if read_type: try: tag_type = NBTagType(buf.read_value(StructFormat.BYTE)) - except OSError: - raise IOError("Buffer is empty.") from None - except ValueError: - raise TypeError("Invalid tag type.") from None - - if tag_type == NBTagType.END: + except OSError as exc: + raise IOError("Buffer is empty.") from exc + except ValueError as exc: + raise TypeError("Invalid tag type.") from exc + else: + tag_type = _get_tag_type(cls) + + if tag_type is NBTagType.END: return "", tag_type name = StringNBT.read_from(buf, with_type=False, with_name=False).value if with_name else "" @@ -268,6 +288,7 @@ def _read_header(cls, buf: Buffer, read_type: bool = True, with_name: bool = Tru return name, tag_type @classmethod + @abstractmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> NBTag: """Read the NBT tag from the buffer. @@ -281,142 +302,133 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The NBT tag. """ - return cls.deserialize(buf, with_name=with_name, with_type=with_type) + ... @staticmethod - def from_object(data: object, /, name: str = "", *, use_int_array: bool = True) -> NBTag: # noqa: PLR0911,PLR0912 - """Create an NBT tag from an arbitrary (compatible) Python object. - - :param data: The object to convert to an NBT tag. - :param use_int_array: Whether to use IntArrayNBT and LongArrayNBT for lists of integers. - If set to False, all lists of integers will be considered as ListNBT. - :param name: The name of the resulting tag. Used for recursive calls. + def from_object(data: FromObjectType, schema: FromObjectSchema, name: str = "") -> NBTag: + """Create an NBT tag from a dictionary. - :return: The NBT tag representing the object. + :param data: The dictionary to create the NBT tag from. + :param schema: The schema used to create the NBT tags. - :note: The function will attempt to convert the object to an NBT tag in the following way: - - If the object is a dictionary with a single key, the key will be used as the name of the tag. - - If the object is an integer, it will be converted to a ByteNBT, ShortNBT, IntNBT, or LongNBT tag - depending on the value. - - If the object is a list, it will be converted to a ListNBT tag. - - If the object is a dictionary, it will be converted to a CompoundNBT tag. - - If the object is a string, it will be converted to a StringNBT tag. - - If the object is a float, it will be converted to a FloatNBT tag. - - If the object can be serialized to bytes, it will be converted to a ByteArrayNBT tag. + If the schema is a list, the data must be a list and the schema must either contain a single element + representing the type of the elements in the list or multiple dictionaries or lists representing the types + of the elements in the list since they are the only types that have a variable type. + Example: + ```python + schema = [IntNBT] + data = [1, 2, 3] + schema = [[IntNBT], [StringNBT]] + data = [[1, 2, 3], ["a", "b", "c"]] + ``` - - If you want an object to be serialized in a specific way, you can implement: + If the schema is a dictionary, the data must be a dictionary and the schema must contain the keys and the + types of the values in the dictionary. + Example: ```python - def to_nbt(self, name: str = "") -> NBTag: - ... + schema = {"key": IntNBT} + data = {"key": 1} ``` + + If the schema is a subclass of NBTag, the data will be passed to the constructor of the schema. + If the schema is not a list, dictionary or subclass of NBTag, the data will be converted to an NBT tag + using the `to_nbt` method of the data. + + :param name: The name of the NBT tag. + + :return: The NBT tag created from the dictionary. """ - if hasattr(data, "to_nbt"): # For objects that can be converted to NBT - return data.to_nbt(name=name) # type: ignore - - if isinstance(data, int): - if -(1 << 7) <= data < 1 << 7: - return ByteNBT(data, name=name) - if -(1 << 15) <= data < 1 << 15: - return ShortNBT(data, name=name) - if -(1 << 31) <= data < 1 << 31: - return IntNBT(data, name=name) - if -(1 << 63) <= data < 1 << 63: - return LongNBT(data, name=name) - raise ValueError(f"Integer {data} is out of range.") - if isinstance(data, float): - return FloatNBT(data, name=name) - if isinstance(data, str): - return StringNBT(data, name=name) - if isinstance(data, (bytearray, bytes)): - if isinstance(data, bytearray): - data = bytes(data) - return ByteArrayNBT(data, name=name) - if isinstance(data, list): - if not data: - # Type END is used to mark an empty list - return ListNBT([], name=name) - first_type = type(data[0]) - if any(type(item) != first_type for item in data): - raise TypeError("All items in a list must be of the same type.") - - if issubclass(first_type, int) and use_int_array: - # Check the range of the integers in the list - use_int = all(-(1 << 31) <= item < 1 << 31 for item in data) - use_long = all(-(1 << 63) <= item < 1 << 63 for item in data) - if use_int: - return IntArrayNBT(data, name=name) - if not use_long: # Too big to fit in a long, won't fit in a List of Longs either - raise ValueError("Integer list contains values out of range.") - return LongArrayNBT(data, name=name) - return ListNBT([NBTag.from_object(item, use_int_array=use_int_array) for item in data], name=name) - if isinstance(data, dict): - if len(data) == 0: - return CompoundNBT([], name=name) - if len(data) == 1 and name == "": - key, value = next(iter(data.items())) - return NBTag.from_object(value, name=key, use_int_array=use_int_array) - payload = [] + if isinstance(schema, (list, tuple)): + if not isinstance(data, list): + raise TypeError("Expected a list, but found a different type.") + payload: list[NBTag] = [] + if len(schema) > 1: + if not all(isinstance(item, (list, dict)) for item in schema): + raise TypeError("Expected a list of lists or dictionaries, but found a different type.") + if len(schema) != len(data): + raise ValueError("The schema and the data must have the same length.") + for item, sub_schema in zip(data, schema): + payload.append(NBTag.from_object(item, sub_schema)) + else: + if len(schema) == 0 and len(data) > 0: + raise ValueError("The schema is empty, but the data is not.") + if len(schema) == 0: + return ListNBT([], name=name) + + schema = schema[0] + for item in data: + payload.append(NBTag.from_object(item, schema)) + return ListNBT(payload, name=name) + if isinstance(schema, dict): + if not isinstance(data, dict): + raise TypeError("Expected a dictionary, but found a different type.") + payload: list[NBTag] = [] for key, value in data.items(): - tag = NBTag.from_object(value, name=key, use_int_array=use_int_array) - payload.append(tag) - return CompoundNBT(payload, name) - if data is None: - warnings.warn("Converting None to an END tag.", stacklevel=2) - return EndNBT() # Should not be used - - try: - # Check if the object can be converted to bytes - return ByteArrayNBT(bytes(data), name=name) # type: ignore - except (TypeError, ValueError): - pass - raise TypeError(f"Cannot convert object of type {type(data)} to an NBT tag.") - - def to_object(self) -> Mapping[str, PayloadType] | PayloadType: - """Convert the NBT payload to a dictionary.""" - return CompoundNBT(self.payload).to_object() # allow NBTag.to_object to act as a dict - - def __getitem__(self, key: str | int) -> PayloadType: - """Get a tag from the list or compound tag.""" - if self.TYPE not in (NBTagType.LIST, NBTagType.COMPOUND, NBTagType.INT_ARRAY, NBTagType.LONG_ARRAY): - raise TypeError(f"Cannot get a tag by index from a non-LIST or non-COMPOUND tag ({self.TYPE}).") - - if not isinstance(self.payload, list): - raise AttributeError( - f"The payload of the tag is not a list ({self.TYPE}).\n" - "Check that the initialization of the tag is correct." + payload.append(NBTag.from_object(value, schema[key], name=key)) + return CompoundNBT(payload, name=name) + if not isinstance(schema, type) or not issubclass(schema, (NBTag, NBTagConvertible)): # type: ignore + raise TypeError("The schema must be a list, dict or a subclass of either NBTag or NBTagConvertible.") + if isinstance(data, schema): + return data.to_nbt(name=name) + schema = cast(Type[NBTag], schema) # Last option + if issubclass(schema, (CompoundNBT, ListNBT)): + raise ValueError("The schema must specify the type of the elements in CompoundNBT and ListNBT tags.") + if isinstance(data, dict): + if len(data) != 1: + raise ValueError("Expected a dictionary with a single key-value pair.") + key, value = next(iter(data.items())) + return schema.from_object(value, schema, name=key) + if not isinstance(data, (bytes, str, int, float, list)): + raise TypeError(f"Expected a bytes, str, int, float, but found {type(data).__name__}.") + if isinstance(data, list) and not all(isinstance(item, int) for item in data): + raise TypeError("Expected a list of integers.") # LongArrayNBT, IntArrayNBT + + data = cast(Union[bytes, str, int, float, List[int]], data) + return schema(data, name=name) + + def to_object( + self, include_schema: bool = False, include_name: bool = False + ) -> PayloadType | Mapping[str, PayloadType] | tuple[PayloadType | Mapping[str, PayloadType], FromObjectSchema]: + """Convert the NBT tag to a python object. + + :param include_schema: Whether to return a schema describing the types of the original tag. + :param include_name: Whether to include the name of the tag in the output. + If the tag has no name, the name will be set to "". + + :return: Either : + - A python object representing the payload of the tag. (default) + - A dictionary containing the name associated with a python object representing the payload of the tag. + - A tuple which includes one of the above and a schema describing the types of the original tag. + """ + if type(self) is EndNBT: + raise NotImplementedError("Cannot convert an EndNBT tag to a python object.") + if type(self) in (CompoundNBT, ListNBT): + raise TypeError( + f"Use the `{type(self).__name__}.to_object()` method to convert the tag to a python object." ) - if not isinstance(key, (str, int)): # type: ignore - raise TypeError("Key must be a string or an integer.") - - if isinstance(key, str): - if self.TYPE != NBTagType.COMPOUND: - raise TypeError(f"Cannot get a tag by name from a non-COMPOUND tag ({self.TYPE}).") - if not all(isinstance(tag, NBTag) for tag in self.payload): - raise AttributeError("The payload of the tag is not a list of NBTag objects.") - for tag in self.payload: - tag = cast(NBTag, tag) - if tag.name == key: - return tag - raise KeyError(f"No tag with the name {key!r} found.") - - # Key is an integer - if key < -len(self.payload) or key >= len(self.payload): - raise IndexError(f"Index {key} out of range.") - return self.payload[key] + result = self.payload if not include_name else {self.name: self.payload} + if include_schema: + return result, type(self) + return result + @override def __repr__(self) -> str: if self.name: - return f"{self.__class__.__name__}[{self.name!r}]({self.payload!r})" - return f"{self.__class__.__name__}({self.payload!r})" + return f"{type(self).__name__}[{self.name!r}]({self.payload!r})" + return f"{type(self).__name__}({self.payload!r})" + @override def __eq__(self, other: object) -> bool: """Check equality between two NBT tags.""" if not isinstance(other, NBTag): raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") - return self.name == other.name and self.TYPE == other.TYPE and self.payload == other.payload + if type(self) is not type(other): + return False + return self.name == other.name and self.payload == other.payload + @override def to_nbt(self, name: str = "") -> NBTag: """Convert the object to an NBT tag. @@ -426,12 +438,10 @@ def to_nbt(self, name: str = "") -> NBTag: return self @property + @abstractmethod def value(self) -> PayloadType: """Get the payload of the NBT tag in a python-friendly format.""" - obj = self.to_object() - if isinstance(obj, dict) and self.name: - return obj[self.name] - return obj + ... # endregion @@ -441,22 +451,23 @@ def value(self) -> PayloadType: class EndNBT(NBTag): """Sentinel tag used to mark the end of a TAG_Compound.""" - TYPE = NBTagType.END __slots__ = () def __init__(self): """Create a new EndNBT tag.""" super().__init__(0, name="") - def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = False) -> None: """Write the EndNBT tag to the buffer. :param buf: The buffer to write to. :param with_type: Whether to include the type of the tag in the serialization. :param with_name: Whether to include the name of the tag in the serialization. """ - self._write_header(buf, with_type=with_type, with_name=with_name) + self._write_header(buf, with_type=with_type, with_name=False) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> EndNBT: """Read the EndNBT tag from the buffer. @@ -469,26 +480,37 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The EndNBT tag. """ _, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") return EndNBT() - def to_object(self) -> Mapping[str, PayloadType]: + @override + def to_object( + self, include_schema: bool = False, include_name: bool = False + ) -> PayloadType | Mapping[str, PayloadType]: """Convert the EndNBT tag to a python object. - :return: An empty dictionary. + :param include_schema: Whether to return a schema describing the types of the original tag. + :param include_name: Whether to include the name of the tag in the output. + + :return: None """ - return {} + return NotImplemented + + @property + @override + def value(self) -> PayloadType: + """Get the payload of the EndNBT tag in a python-friendly format.""" + return NotImplemented class ByteNBT(NBTag): """NBT tag representing a single byte value, represented as a signed 8-bit integer.""" - TYPE = NBTagType.BYTE - __slots__ = () payload: int + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the ByteNBT tag to the buffer. @@ -502,6 +524,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) buf.write_value(StructFormat.BYTE, self.payload) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteNBT: """Read the ByteNBT tag from the buffer. @@ -514,8 +537,8 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The ByteNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") if buf.remaining < 1: raise IOError("Buffer does not contain enough data to read a byte. (Empty buffer)") @@ -526,17 +549,8 @@ def __int__(self) -> int: """Get the integer value of the ByteNBT tag.""" return self.payload - def to_object(self) -> Mapping[str, int] | int: - """Convert the ByteNBT tag to a python object. - - :return: A dictionary containing the name and the integer value of the tag. If the tag has no name, the value - will be returned directly. - """ - if self.name: - return {self.name: self.payload} - return self.payload - @property + @override def value(self) -> int: """Get the integer value of the IntNBT tag.""" return self.payload @@ -545,10 +559,9 @@ def value(self) -> int: class ShortNBT(ByteNBT): """NBT tag representing a short value, represented as a signed 16-bit integer.""" - TYPE = NBTagType.SHORT - __slots__ = () + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the ShortNBT tag to the buffer. @@ -565,6 +578,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) buf.write(self.payload.to_bytes(2, "big", signed=True)) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ShortNBT: """Read the ShortNBT tag from the buffer. @@ -577,8 +591,8 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The ShortNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") if buf.remaining < 2: raise IOError("Buffer does not contain enough data to read a short.") @@ -589,10 +603,9 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) class IntNBT(ByteNBT): """NBT tag representing an integer value, represented as a signed 32-bit integer.""" - TYPE = NBTagType.INT - __slots__ = () + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the IntNBT tag to the buffer. @@ -610,6 +623,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) # No more messing around with the struct, we want 32 bits of data no matter what buf.write(self.payload.to_bytes(4, "big", signed=True)) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntNBT: """Read the IntNBT tag from the buffer. @@ -622,8 +636,8 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The IntNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") if buf.remaining < 4: raise IOError("Buffer does not contain enough data to read an int.") @@ -634,10 +648,9 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) class LongNBT(ByteNBT): """NBT tag representing a long value, represented as a signed 64-bit integer.""" - TYPE = NBTagType.LONG - __slots__ = () + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the LongNBT tag to the buffer. @@ -655,6 +668,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) # No more messing around with the struct, we want 64 bits of data no matter what buf.write(self.payload.to_bytes(8, "big", signed=True)) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongNBT: """Read the LongNBT tag from the buffer. @@ -667,8 +681,8 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The LongNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") if buf.remaining < 8: raise IOError("Buffer does not contain enough data to read a long.") @@ -680,12 +694,11 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) class FloatNBT(NBTag): """NBT tag representing a floating-point value, represented as a 32-bit IEEE 754-2008 binary32 value.""" - TYPE = NBTagType.FLOAT - payload: float __slots__ = () + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the FloatNBT tag to the buffer. @@ -698,6 +711,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) self._write_header(buf, with_type=with_type, with_name=with_name) buf.write_value(StructFormat.FLOAT, self.payload) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> FloatNBT: """Read the FloatNBT tag from the buffer. @@ -710,8 +724,8 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The FloatNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") if buf.remaining < 4: raise IOError("Buffer does not contain enough data to read a float.") @@ -722,16 +736,7 @@ def __float__(self) -> float: """Get the float value of the FloatNBT tag.""" return self.payload - def to_object(self) -> Mapping[str, float] | float: - """Convert the FloatNBT tag to a python object. - - :return: A dictionary containing the name and the float value of the tag. If the tag has no name, the value - will be returned directly. - """ - if self.name: - return {self.name: self.payload} - return self.payload - + @override def __eq__(self, other: object) -> bool: """Check equality between two FloatNBT tags. @@ -744,14 +749,15 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, NBTag): raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") # Compare the float values with a small epsilon - if not (self.name == other.name and self.TYPE == other.TYPE): + if type(self) is not type(other): + return False + other.payload = cast(float, other.payload) + if self.name != other.name: return False - if not isinstance(other, self.__class__): # pragma: no cover - return False # Should not happen if nobody messes with the TYPE attribute - return abs(self.payload - other.payload) < 1e-6 @property + @override def value(self) -> float: """Get the float value of the FloatNBT tag.""" return self.payload @@ -760,10 +766,9 @@ def value(self) -> float: class DoubleNBT(FloatNBT): """NBT tag representing a double-precision floating-point value, represented as a 64-bit IEEE 754-2008 binary64.""" - TYPE = NBTagType.DOUBLE - __slots__ = () + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the DoubleNBT tag to the buffer. @@ -776,6 +781,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) self._write_header(buf, with_type=with_type, with_name=with_name) buf.write_value(StructFormat.DOUBLE, self.payload) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> DoubleNBT: """Read the DoubleNBT tag from the buffer. @@ -788,8 +794,8 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The DoubleNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") if buf.remaining < 8: raise IOError("Buffer does not contain enough data to read a double.") @@ -800,12 +806,11 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) class ByteArrayNBT(NBTag): """NBT tag representing an array of bytes. The length of the array is stored as a signed 32-bit integer.""" - TYPE = NBTagType.BYTE_ARRAY - __slots__ = () - payload: bytearray + payload: bytes + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the ByteArrayNBT tag to the buffer. @@ -819,6 +824,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) IntNBT(len(self.payload)).write_to(buf, with_type=False, with_name=False) buf.write(self.payload) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteArrayNBT: """Read the ByteArrayNBT tag from the buffer. @@ -831,12 +837,12 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The ByteArrayNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") try: length = IntNBT.read_from(buf, with_type=False, with_name=False).value - except IOError: - raise IOError("Buffer does not contain enough data to read a byte array.") from None + except IOError as exc: + raise IOError("Buffer does not contain enough data to read a byte array.") from exc if length < 0: raise ValueError("Invalid byte array length.") @@ -846,32 +852,24 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) f"Buffer does not contain enough data to read the byte array ({buf.remaining} < {length} bytes)." ) - return ByteArrayNBT(buf.read(length), name=name) + return ByteArrayNBT(bytes(buf.read(length)), name=name) def __bytes__(self) -> bytes: """Get the bytes value of the ByteArrayNBT tag.""" return self.payload - def to_object(self) -> Mapping[str, bytearray] | bytearray: - """Convert the ByteArrayNBT tag to a python object. - - :return: A dictionary containing the name and the byte array value of the tag. If the tag has no name, the - value will be returned directly. - """ - if self.name: - return {self.name: self.payload} - return self.payload - + @override def __repr__(self) -> str: """Get a string representation of the ByteArrayNBT tag.""" if self.name: - return f"{self.__class__.__name__}[{self.name!r}](length={len(self.payload)})" + return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)})" if len(self.payload) < 8: - return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload!r})" - return f"{self.__class__.__name__}(length={len(self.payload)}, {bytes(self.payload[:7])!r}...)" + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}(length={len(self.payload)}, {bytes(self.payload[:7])!r}...)" @property - def value(self) -> bytearray: + @override + def value(self) -> bytes: """Get the bytes value of the ByteArrayNBT tag.""" return self.payload @@ -879,12 +877,11 @@ def value(self) -> bytearray: class StringNBT(NBTag): """NBT tag representing an UTF-8 string value. The length of the string is stored as a signed 16-bit integer.""" - TYPE = NBTagType.STRING - __slots__ = () payload: str + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the StringNBT tag to the buffer. @@ -899,10 +896,11 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) # Check the length of the string (can't generate strings that long in tests) raise ValueError("Maximum character limit for writing strings is 32767 characters.") # pragma: no cover - data = bytearray(self.payload, "utf-8") + data = bytes(self.payload, "utf-8") ShortNBT(len(data)).write_to(buf, with_type=False, with_name=False) buf.write(data) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> StringNBT: """Read the StringNBT tag from the buffer. @@ -915,12 +913,12 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The StringNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") try: length = ShortNBT.read_from(buf, with_type=False, with_name=False).value - except IOError: - raise IOError("Buffer does not contain enough data to read a string.") from None + except IOError as exc: + raise IOError("Buffer does not contain enough data to read a string.") from exc if length < 0: raise ValueError("Invalid string length.") @@ -933,21 +931,13 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) except UnicodeDecodeError: raise # We want to know it + @override def __str__(self) -> str: """Get the string value of the StringNBT tag.""" return self.payload - def to_object(self) -> Mapping[str, str] | str: - """Convert the StringNBT tag to a python object. - - :return: A dictionary containing the name and the string value of the tag. If the tag has no name, the value - will be returned directly. - """ - if self.name: - return {self.name: self.payload} - return self.payload - @property + @override def value(self) -> str: """Get the string value of the StringNBT tag.""" return self.payload @@ -956,12 +946,11 @@ def value(self) -> str: class ListNBT(NBTag): """NBT tag representing a list of tags. All tags in the list must be of the same type.""" - TYPE = NBTagType.LIST - __slots__ = () payload: list[NBTag] + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the ListNBT tag to the buffer. @@ -986,17 +975,18 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) "objects to tags first." ) - tag_type = self.payload[0].TYPE + tag_type = _get_tag_type(self.payload[0]) ByteNBT(tag_type).write_to(buf, with_name=False, with_type=False) IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False) for tag in self.payload: - if tag_type != tag.TYPE: + if tag_type != _get_tag_type(tag): raise ValueError(f"All tags in a list must be of the same type, got tag {tag!r}") if tag.name != "": raise ValueError(f"All tags in a list must be unnamed, got tag {tag!r}") tag.write_to(buf, with_type=False, with_name=False) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ListNBT: """Read the ListNBT tag from the buffer. @@ -1009,69 +999,104 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The ListNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") list_tag_type = ByteNBT.read_from(buf, with_type=False, with_name=False).payload try: length = IntNBT.read_from(buf, with_type=False, with_name=False).value - except IOError: - raise IOError("Buffer does not contain enough data to read a list.") from None + except IOError as exc: + raise IOError("Buffer does not contain enough data to read a list.") from exc - if length < 0 or list_tag_type == NBTagType.END: + if length < 1 or list_tag_type is NBTagType.END: return ListNBT([], name=name) try: list_tag_type = NBTagType(list_tag_type) - except ValueError: - raise TypeError(f"Unknown tag type {list_tag_type}.") from None + except ValueError as exc: + raise TypeError(f"Unknown tag type {list_tag_type}.") from exc - list_type_class = NBTag.ASSOCIATED_TYPES.get(list_tag_type, NBTag) - if list_type_class == NBTag: + list_type_class = ASSOCIATED_TYPES.get(list_tag_type, NBTag) + if list_type_class is NBTag: raise TypeError(f"Unknown tag type {list_tag_type}.") # pragma: no cover try: - payload = [ - # The type is already known, so we don't need to read it again - # List items are unnamed, so we don't need to read the name - list_type_class.read_from(buf, with_type=False, with_name=False) - for _ in range(length) - ] - except IOError: - raise IOError("Buffer does not contain enough data to read the list.") from None + payload = [list_type_class.read_from(buf, with_type=False, with_name=False) for _ in range(length)] + except IOError as exc: + raise IOError("Buffer does not contain enough data to read the list.") from exc return ListNBT(payload, name=name) - def __iter__(self): + def __iter__(self) -> Iterator[NBTag]: """Iterate over the tags in the list.""" yield from self.payload + @override def __repr__(self) -> str: """Get a string representation of the ListNBT tag.""" if self.name: - return f"{self.__class__.__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" if len(self.payload) < 8: - return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload!r})" - return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" - - def to_object(self) -> Mapping[str, list[PayloadType]] | list[PayloadType]: + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" + + @override + def to_object( + self, include_schema: bool = False, include_name: bool = False + ) -> ( + list[PayloadType] + | Mapping[str, list[PayloadType]] + | tuple[list[PayloadType] | Mapping[str, list[PayloadType]], list[FromObjectSchema]] + ): """Convert the ListNBT tag to a python object. - :return: A dictionary containing the name and the list of tags. If the tag has no name, the list will be - returned directly. + :param include_schema: Whether to return a schema describing the types of the original tag. + :param include_name: Whether to include the name of the tag in the output. + If the tag has no name, the name will be set to "". + + :return: Either : + - A list containing the payload of the tag. (default) + - A dictionary containing the name associated with a list containing the payload of the tag. + - A tuple which includes one of the above and a list of schemas describing the types of the original tag. """ - self.payload: list[NBTag] - if self.name: - return {self.name: [tag.to_object() for tag in self.payload]} # Extract the (unnamed) object from each tag - return [tag.to_object() for tag in self.payload] # Extract the (unnamed) object from each tag + result = [tag.to_object() for tag in self.payload] + result = cast(List[PayloadType], result) + result = result if not include_name else {self.name: result} + if include_schema: + subschemas = [ + cast( + Tuple[PayloadType, FromObjectSchema], + tag.to_object(include_schema=True), + )[1] + for tag in self.payload + ] + if len(result) == 0: + return result, [] + + first = subschemas[0] + if all(schema == first for schema in subschemas): + return result, [first] + + if not isinstance(first, (dict, list)): + raise TypeError(f"The schema must contain either a dict or a list. Found {first!r}") + # This will take care of ensuring either everything is a dict or a list + if not all(isinstance(schema, type(first)) for schema in subschemas): + raise TypeError(f"All items in the list must have the same type. Found {subschemas!r}") + return result, subschemas + return result + + @property + @override + def value(self) -> list[PayloadType]: + """Get the payload of the ListNBT tag in a python-friendly format.""" + return [tag.value for tag in self.payload] class CompoundNBT(NBTag): """NBT tag representing a compound of named tags.""" - TYPE = NBTagType.COMPOUND - __slots__ = () payload: list[NBTag] + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the CompoundNBT tag to the buffer. @@ -1102,6 +1127,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) tag.write_to(buf) EndNBT().write_to(buf, with_name=False, with_type=True) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> CompoundNBT: """Read the CompoundNBT tag from the buffer. @@ -1114,16 +1140,16 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) :return: The CompoundNBT tag. """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) - if tag_type != cls.TYPE: - raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") - payload = [] + payload: list[NBTag] = [] while True: child_name, child_type = cls._read_header(buf, with_name=True, read_type=True) - if child_type == NBTagType.END: + if child_type is NBTagType.END: break # The name and type of the tag have already been read - tag = NBTag.ASSOCIATED_TYPES[child_type].read_from(buf, with_type=False, with_name=False) + tag = ASSOCIATED_TYPES[child_type].read_from(buf, with_type=False, with_name=False) tag.name = child_name payload.append(tag) return CompoundNBT(payload, name=name) @@ -1133,29 +1159,50 @@ def __iter__(self): for tag in self.payload: yield tag.name, tag + @override def __repr__(self) -> str: """Get a string representation of the CompoundNBT tag.""" if self.name: - return f"{self.__class__.__name__}[{self.name!r}]({dict(self)})" - return f"{self.__class__.__name__}({dict(self)})" - - def to_object(self) -> Mapping[str, Mapping[str, PayloadType]]: + return f"{type(self).__name__}[{self.name!r}]({dict(self)})" + return f"{type(self).__name__}({dict(self)})" + + @override + def to_object( + self, include_schema: bool = False, include_name: bool = False + ) -> ( + Mapping[str, PayloadType] + | Mapping[str, Mapping[str, PayloadType]] + | tuple[ + Mapping[str, PayloadType] | Mapping[str, Mapping[str, PayloadType]], + Mapping[str, FromObjectSchema], + ] + ): """Convert the CompoundNBT tag to a python object. - :return: A dictionary containing the name and the dictionary of tags. If the tag has no name, the dictionary - will be returned directly. + :param include_schema: Whether to return a schema describing the types of the original tag and its children. + :param include_name: Whether to include the name of the tag in the output. + If the tag has no name, the name will be set to "". + + :return: Either : + - A dictionary containing the payload of the tag. (default) + - A dictionary containing the name associated with a dictionary containing the payload of the tag. + - A tuple which includes one of the above and a dictionary of schemas describing the types of the original tag. """ - result = {} - for tag in self.payload: - if tag.name in result: - raise ValueError(f"Duplicate tag name {tag.name!r} in the compound.") - if tag.name == "": - raise ValueError("All tags in a compound must have a name.") - result.update(cast("dict[str, PayloadType]", tag.to_object())) - if self.name: - return {self.name: result} + result = {tag.name: tag.to_object() for tag in self.payload} + result = cast(Mapping[str, PayloadType], result) + result = result if not include_name else {self.name: result} + if include_schema: + subschemas = { + tag.name: cast( + Tuple[PayloadType, FromObjectSchema], + tag.to_object(include_schema=True), + )[1] + for tag in self.payload + } + return result, subschemas return result + @override def __eq__(self, other: object) -> bool: """Check equality between two CompoundNBT tags. @@ -1169,24 +1216,30 @@ def __eq__(self, other: object) -> bool: # The order of the tags is not guaranteed if not isinstance(other, NBTag): raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") - if self.name != other.name or self.TYPE != other.TYPE: + if type(self) is not type(other): + return False + if self.name != other.name: return False - if not isinstance(other, self.__class__): # pragma: no cover - return False # Should not happen if nobody messes with the TYPE attribute + other = cast(CompoundNBT, other) if len(self.payload) != len(other.payload): return False return all(tag in other.payload for tag in self.payload) + @property + @override + def value(self) -> dict[str, PayloadType]: + """Get the dictionary of tags in the CompoundNBT tag.""" + return {tag.name: tag.value for tag in self.payload} + class IntArrayNBT(NBTag): """NBT tag representing an array of integers. The length of the array is stored as a signed 32-bit integer.""" - TYPE = NBTagType.INT_ARRAY - __slots__ = () payload: list[int] + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the IntArrayNBT tag to the buffer. @@ -1208,6 +1261,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) for i in self.payload: IntNBT(i).write_to(buf, with_name=False, with_type=False) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntArrayNBT: """Read the IntArrayNBT tag from the buffer. @@ -1224,36 +1278,28 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) raise TypeError(f"Expected an INT_ARRAY tag, but found a different tag ({tag_type}).") length = IntNBT.read_from(buf, with_type=False, with_name=False).value try: - payload = [IntNBT.read_from(buf, with_type == NBTagType.INT, with_name=False).value for _ in range(length)] - except IOError: + payload = [IntNBT.read_from(buf, with_type is NBTagType.INT, with_name=False).value for _ in range(length)] + except IOError as exc: raise IOError( "Buffer does not contain enough data to read the entire integer array. (Incomplete data)" - ) from None + ) from exc return IntArrayNBT(payload, name=name) + @override def __repr__(self) -> str: """Get a string representation of the IntArrayNBT tag.""" if self.name: - return f"{self.__class__.__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" if len(self.payload) < 8: - return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload!r})" - return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" - def __iter__(self): + def __iter__(self) -> Iterator[int]: """Iterate over the integers in the array.""" yield from self.payload - def to_object(self) -> Mapping[str, list[int]] | list[int]: - """Convert the IntArrayNBT tag to a python object. - - :return: A dictionary containing the name and the list of integers. If the tag has no name, the list will be - returned directly. - """ - if self.name: - return {self.name: self.payload} - return self.payload - @property + @override def value(self) -> list[int]: """Get the list of integers in the IntArrayNBT tag.""" return self.payload @@ -1262,10 +1308,9 @@ def value(self) -> list[int]: class LongArrayNBT(IntArrayNBT): """NBT tag representing an array of longs. The length of the array is stored as a signed 32-bit integer.""" - TYPE = NBTagType.LONG_ARRAY - __slots__ = () + @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: """Write the LongArrayNBT tag to the buffer. @@ -1287,6 +1332,7 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) for i in self.payload: LongNBT(i).write_to(buf, with_name=False, with_type=False) + @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongArrayNBT: """Read the LongArrayNBT tag from the buffer. @@ -1305,11 +1351,49 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) try: payload = [LongNBT.read_from(buf, with_type=False, with_name=False).payload for _ in range(length)] - except IOError: + except IOError as exc: raise IOError( "Buffer does not contain enough data to read the entire long array. (Incomplete data)" - ) from None + ) from exc return LongArrayNBT(payload, name=name) # endregion + +# region: NBT Associated Types +ASSOCIATED_TYPES: dict[NBTagType, type[NBTag]] = { + NBTagType.END: EndNBT, + NBTagType.BYTE: ByteNBT, + NBTagType.SHORT: ShortNBT, + NBTagType.INT: IntNBT, + NBTagType.LONG: LongNBT, + NBTagType.FLOAT: FloatNBT, + NBTagType.DOUBLE: DoubleNBT, + NBTagType.BYTE_ARRAY: ByteArrayNBT, + NBTagType.STRING: StringNBT, + NBTagType.LIST: ListNBT, + NBTagType.COMPOUND: CompoundNBT, + NBTagType.INT_ARRAY: IntArrayNBT, + NBTagType.LONG_ARRAY: LongArrayNBT, +} + + +def _get_tag_type(tag: NBTag | type[NBTag]) -> NBTagType: + """Get the tag type of an NBTag object or class. + + :param tag: The tag to get the type of. + + :return: The tag type of the tag. + """ + cls = tag if isinstance(tag, type) else type(tag) + + if cls is NBTag: + return NBTagType.COMPOUND + for tag_type, tag_cls in ASSOCIATED_TYPES.items(): + if cls is tag_cls: + return tag_type + + raise ValueError(f"Unknown tag type {cls}.") # pragma: no cover + + +# endregion diff --git a/tests/mcproto/types/test_nbt.py b/tests/mcproto/types/test_nbt.py index 18e5b37c..f68f6e16 100644 --- a/tests/mcproto/types/test_nbt.py +++ b/tests/mcproto/types/test_nbt.py @@ -1,6 +1,8 @@ from __future__ import annotations import struct +from typing import Any, Dict, List, cast +from typing_extensions import override import pytest @@ -22,6 +24,7 @@ PayloadType, ShortNBT, StringNBT, + NBTagConvertible, ) # region EndNBT @@ -41,7 +44,7 @@ def test_serialize_deserialize_end(): assert buffer == bytearray.fromhex("00") buffer = Buffer(bytearray.fromhex("00")) - assert NBTag.deserialize(buffer).TYPE == NBTagType.END + assert EndNBT.deserialize(buffer) == EndNBT() # endregion @@ -86,9 +89,21 @@ def test_serialize_deserialize_end(): (ByteArrayNBT, b"", bytearray.fromhex("07 00 00 00 00")), (ByteArrayNBT, b"\x00", bytearray.fromhex("07 00 00 00 01") + b"\x00"), (ByteArrayNBT, b"\x00\x01", bytearray.fromhex("07 00 00 00 02") + b"\x00\x01"), - (ByteArrayNBT, b"\x00\x01\x02", bytearray.fromhex("07 00 00 00 03") + b"\x00\x01\x02"), - (ByteArrayNBT, b"\x00\x01\x02\x03", bytearray.fromhex("07 00 00 00 04") + b"\x00\x01\x02\x03"), - (ByteArrayNBT, b"\xFF" * 1024, bytearray.fromhex("07 00 00 04 00") + b"\xFF" * 1024), + ( + ByteArrayNBT, + b"\x00\x01\x02", + bytearray.fromhex("07 00 00 00 03") + b"\x00\x01\x02", + ), + ( + ByteArrayNBT, + b"\x00\x01\x02\x03", + bytearray.fromhex("07 00 00 00 04") + b"\x00\x01\x02\x03", + ), + ( + ByteArrayNBT, + b"\xff" * 1024, + bytearray.fromhex("07 00 00 04 00") + b"\xff" * 1024, + ), ( ByteArrayNBT, bytes((n - 1) * n * 2 % 256 for n in range(256)), @@ -100,7 +115,11 @@ def test_serialize_deserialize_end(): (StringNBT, "&à@é", bytearray.fromhex("08 00 06") + bytes("&à@é", "utf-8")), (ListNBT, [], bytearray.fromhex("09 00 00 00 00 00")), (ListNBT, [ByteNBT(0)], bytearray.fromhex("09 01 00 00 00 01 00")), - (ListNBT, [ShortNBT(127), ShortNBT(256)], bytearray.fromhex("09 02 00 00 00 02 00 7F 01 00")), + ( + ListNBT, + [ShortNBT(127), ShortNBT(256)], + bytearray.fromhex("09 02 00 00 00 02 00 7F 01 00"), + ), ( ListNBT, [ListNBT([ByteNBT(0)]), ListNBT([IntNBT(256)])], @@ -124,7 +143,10 @@ def test_serialize_deserialize_end(): ), ( CompoundNBT, - [CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test"), IntNBT(-1, "Int 2")], + [ + CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test"), + IntNBT(-1, "Int 2"), + ], bytearray.fromhex("0A") + CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test").serialize() + IntNBT(-1, "Int 2").serialize() @@ -132,15 +154,43 @@ def test_serialize_deserialize_end(): ), (IntArrayNBT, [], bytearray.fromhex("0B 00 00 00 00")), (IntArrayNBT, [0], bytearray.fromhex("0B 00 00 00 01 00 00 00 00")), - (IntArrayNBT, [0, 1], bytearray.fromhex("0B 00 00 00 02 00 00 00 00 00 00 00 01")), - (IntArrayNBT, [1, 2, 3], bytearray.fromhex("0B 00 00 00 03 00 00 00 01 00 00 00 02 00 00 00 03")), + ( + IntArrayNBT, + [0, 1], + bytearray.fromhex("0B 00 00 00 02 00 00 00 00 00 00 00 01"), + ), + ( + IntArrayNBT, + [1, 2, 3], + bytearray.fromhex("0B 00 00 00 03 00 00 00 01 00 00 00 02 00 00 00 03"), + ), (IntArrayNBT, [(1 << 31) - 1], bytearray.fromhex("0B 00 00 00 01 7F FF FF FF")), - (IntArrayNBT, [(1 << 31) - 1, (1 << 31) - 2], bytearray.fromhex("0B 00 00 00 02 7F FF FF FF 7F FF FF FE")), - (IntArrayNBT, [-1, -2, -3], bytearray.fromhex("0B 00 00 00 03 FF FF FF FF FF FF FF FE FF FF FF FD")), - (IntArrayNBT, [12] * 1024, bytearray.fromhex("0B 00 00 04 00") + b"\x00\x00\x00\x0C" * 1024), + ( + IntArrayNBT, + [(1 << 31) - 1, (1 << 31) - 2], + bytearray.fromhex("0B 00 00 00 02 7F FF FF FF 7F FF FF FE"), + ), + ( + IntArrayNBT, + [-1, -2, -3], + bytearray.fromhex("0B 00 00 00 03 FF FF FF FF FF FF FF FE FF FF FF FD"), + ), + ( + IntArrayNBT, + [12] * 1024, + bytearray.fromhex("0B 00 00 04 00") + b"\x00\x00\x00\x0c" * 1024, + ), (LongArrayNBT, [], bytearray.fromhex("0C 00 00 00 00")), - (LongArrayNBT, [0], bytearray.fromhex("0C 00 00 00 01 00 00 00 00 00 00 00 00")), - (LongArrayNBT, [0, 1], bytearray.fromhex("0C 00 00 00 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 01")), + ( + LongArrayNBT, + [0], + bytearray.fromhex("0C 00 00 00 01 00 00 00 00 00 00 00 00"), + ), + ( + LongArrayNBT, + [0, 1], + bytearray.fromhex("0C 00 00 00 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 01"), + ), ( LongArrayNBT, [1, 2, 3], @@ -148,7 +198,11 @@ def test_serialize_deserialize_end(): "0C 00 00 00 03 00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 03" ), ), - (LongArrayNBT, [(1 << 63) - 1], bytearray.fromhex("0C 00 00 00 01 7F FF FF FF FF FF FF FF")), + ( + LongArrayNBT, + [(1 << 63) - 1], + bytearray.fromhex("0C 00 00 00 01 7F FF FF FF FF FF FF FF"), + ), ( LongArrayNBT, [(1 << 63) - 1, (1 << 63) - 2], @@ -161,7 +215,11 @@ def test_serialize_deserialize_end(): "0C 00 00 00 03 FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF FE FF FF FF FF FF FF FF FD" ), ), - (LongArrayNBT, [12] * 1024, bytearray.fromhex("0C 00 00 04 00") + b"\x00\x00\x00\x00\x00\x00\x00\x0C" * 1024), + ( + LongArrayNBT, + [12] * 1024, + bytearray.fromhex("0C 00 00 04 00") + b"\x00\x00\x00\x00\x00\x00\x00\x0c" * 1024, + ), ], ) def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType, expected_bytes: bytes): @@ -193,33 +251,108 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType @pytest.mark.parametrize( ("nbt_class", "value", "name", "expected_bytes"), [ - (ByteNBT, 0, "test", bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("00")), - (ByteNBT, 1, "a", bytearray.fromhex("01") + b"\x00\x01a" + bytearray.fromhex("01")), - (ByteNBT, 127, "&à@é", bytearray.fromhex("01 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F")), - (ByteNBT, -128, "test", bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("80")), - (ByteNBT, 12, "a" * 100, bytearray.fromhex("01") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("0C")), - (ShortNBT, 0, "test", bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("00 00")), - (ShortNBT, 1, "a", bytearray.fromhex("02") + b"\x00\x01a" + bytearray.fromhex("00 01")), - (ShortNBT, 32767, "&à@é", bytearray.fromhex("02 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF")), - (ShortNBT, -32768, "test", bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("80 00")), - (ShortNBT, 12, "a" * 100, bytearray.fromhex("02") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 0C")), - (IntNBT, 0, "test", bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")), - (IntNBT, 1, "a", bytearray.fromhex("03") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01")), + ( + ByteNBT, + 0, + "test", + bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("00"), + ), + ( + ByteNBT, + 1, + "a", + bytearray.fromhex("01") + b"\x00\x01a" + bytearray.fromhex("01"), + ), + ( + ByteNBT, + 127, + "&à@é", + bytearray.fromhex("01 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F"), + ), + ( + ByteNBT, + -128, + "test", + bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("80"), + ), + ( + ByteNBT, + 12, + "a" * 100, + bytearray.fromhex("01") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("0C"), + ), + ( + ShortNBT, + 0, + "test", + bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("00 00"), + ), + ( + ShortNBT, + 1, + "a", + bytearray.fromhex("02") + b"\x00\x01a" + bytearray.fromhex("00 01"), + ), + ( + ShortNBT, + 32767, + "&à@é", + bytearray.fromhex("02 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF"), + ), + ( + ShortNBT, + -32768, + "test", + bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("80 00"), + ), + ( + ShortNBT, + 12, + "a" * 100, + bytearray.fromhex("02") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 0C"), + ), + ( + IntNBT, + 0, + "test", + bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), + ), + ( + IntNBT, + 1, + "a", + bytearray.fromhex("03") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01"), + ), ( IntNBT, 2147483647, "&à@é", bytearray.fromhex("03 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF FF FF"), ), - (IntNBT, -2147483648, "test", bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00")), + ( + IntNBT, + -2147483648, + "test", + bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00"), + ), ( IntNBT, 12, "a" * 100, bytearray.fromhex("03") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 0C"), ), - (LongNBT, 0, "test", bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00 00 00 00")), - (LongNBT, 1, "a", bytearray.fromhex("04") + b"\x00\x01a" + bytearray.fromhex("00 00 00 00 00 00 00 01")), + ( + LongNBT, + 0, + "test", + bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00 00 00 00"), + ), + ( + LongNBT, + 1, + "a", + bytearray.fromhex("04") + b"\x00\x01a" + bytearray.fromhex("00 00 00 00 00 00 00 01"), + ), ( LongNBT, (1 << 63) - 1, @@ -238,25 +371,60 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType "a" * 100, bytearray.fromhex("04") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 00 00 00 00 0C"), ), - (FloatNBT, 1.0, "test", bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 1.0))), - (FloatNBT, 3.14, "a", bytearray.fromhex("05") + b"\x00\x01a" + bytes(struct.pack(">f", 3.14))), + ( + FloatNBT, + 1.0, + "test", + bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 1.0)), + ), + ( + FloatNBT, + 3.14, + "a", + bytearray.fromhex("05") + b"\x00\x01a" + bytes(struct.pack(">f", 3.14)), + ), ( FloatNBT, -1.0, "&à@é", bytearray.fromhex("05 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">f", -1.0)), ), - (FloatNBT, 12.0, "test", bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 12.0))), - (DoubleNBT, 1.0, "test", bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 1.0))), - (DoubleNBT, 3.14, "a", bytearray.fromhex("06") + b"\x00\x01a" + bytes(struct.pack(">d", 3.14))), + ( + FloatNBT, + 12.0, + "test", + bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 12.0)), + ), + ( + DoubleNBT, + 1.0, + "test", + bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 1.0)), + ), + ( + DoubleNBT, + 3.14, + "a", + bytearray.fromhex("06") + b"\x00\x01a" + bytes(struct.pack(">d", 3.14)), + ), ( DoubleNBT, -1.0, "&à@é", bytearray.fromhex("06 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">d", -1.0)), ), - (DoubleNBT, 12.0, "test", bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 12.0))), - (ByteArrayNBT, b"", "test", bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")), + ( + DoubleNBT, + 12.0, + "test", + bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 12.0)), + ), + ( + ByteArrayNBT, + b"", + "test", + bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), + ), ( ByteArrayNBT, b"\x00", @@ -277,12 +445,22 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType ), ( ByteArrayNBT, - b"\xFF" * 1024, + b"\xff" * 1024, "a" * 100, - bytearray.fromhex("07") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 04 00") + b"\xFF" * 1024, + bytearray.fromhex("07") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 04 00") + b"\xff" * 1024, + ), + ( + StringNBT, + "", + "test", + bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 00"), + ), + ( + StringNBT, + "test", + "a", + bytearray.fromhex("08") + b"\x00\x01a" + bytearray.fromhex("00 04") + b"test", ), - (StringNBT, "", "test", bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 00")), - (StringNBT, "test", "a", bytearray.fromhex("08") + b"\x00\x01a" + bytearray.fromhex("00 04") + b"test"), ( StringNBT, "a" * 100, @@ -295,7 +473,12 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType "test", bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 06") + bytes("&à@é", "utf-8"), ), - (ListNBT, [], "test", bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00")), + ( + ListNBT, + [], + "test", + bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00"), + ), ( ListNBT, [ByteNBT(-1)], @@ -316,7 +499,12 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType + b"\x00\x01a" + bytearray.fromhex("09 00 00 00 02 01 00 00 00 01 FF 03 00 00 00 01 00 00 01 00"), ), - (CompoundNBT, [], "test", bytearray.fromhex("0A") + b"\x00\x04test" + bytearray.fromhex("00")), + ( + CompoundNBT, + [], + "test", + bytearray.fromhex("0A") + b"\x00\x04test" + bytearray.fromhex("00"), + ), ( CompoundNBT, [ByteNBT(0, name="Byte")], @@ -348,7 +536,12 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType "test", bytearray.fromhex("0A") + b"\x00\x04test" + ListNBT([ByteNBT(0)], name="List").serialize() + b"\x00", ), - (IntArrayNBT, [], "test", bytearray.fromhex("0B") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")), + ( + IntArrayNBT, + [], + "test", + bytearray.fromhex("0B") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), + ), ( IntArrayNBT, [0], @@ -381,9 +574,14 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 01") - + b"\x7F\xFF\xFF\xFF", + + b"\x7f\xff\xff\xff", + ), + ( + LongArrayNBT, + [], + "test", + bytearray.fromhex("0C") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), ), - (LongArrayNBT, [], "test", bytearray.fromhex("0C") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")), ( LongArrayNBT, [0], @@ -419,7 +617,7 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 64") - + b"\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF" * 100, + + b"\x7f\xff\xff\xff\xff\xff\xff\xff" * 100, ), ], ) @@ -471,9 +669,6 @@ def test_serialize_deserialize_numerical_fail(nbt_class: type[NBTag], size: int, with pytest.raises(OverflowError): nbt_class(-(1 << (size - 1)) - 1).serialize(with_name=False) - with pytest.raises(ValueError): # No name - nbt_class(0, "").serialize() # without with_name=False - # Deserialization buffer = Buffer(bytearray([tag.value + 1] + [0] * (size // 8))) with pytest.raises(TypeError): # Tries to read a nbt_class, but it's one higher @@ -495,9 +690,6 @@ def test_serialize_deserialize_numerical_fail(nbt_class: type[NBTag], size: int, def test_serialize_deserialize_float_fail(): """Test serialization/deserialization of NBT FLOAT tag with invalid value.""" - with pytest.raises(ValueError): - FloatNBT(0, 0).serialize() # type:ignore - with pytest.raises(struct.error): FloatNBT("test").serialize(with_name=False) @@ -524,9 +716,6 @@ def test_serialize_deserialize_float_fail(): def test_serialize_deserialize_double_fail(): """Test serialization/deserialization of NBT DOUBLE tag with invalid value.""" - with pytest.raises(ValueError): - DoubleNBT(0, 0).serialize() # type: ignore - with pytest.raises(struct.error): DoubleNBT("test").serialize(with_name=False) @@ -547,12 +736,6 @@ def test_serialize_deserialize_double_fail(): def test_serialize_deserialize_bytearray_fail(): """Test serialization/deserialization of NBT BYTEARRAY tag with invalid value.""" - with pytest.raises(ValueError): - ByteArrayNBT([], 0).serialize() # type:ignore - - with pytest.raises(ValueError): - ByteArrayNBT(b"test", "").serialize() - # Deserialization buffer = Buffer(bytearray([0x01] + [0] * 4)) with pytest.raises(TypeError): # Tries to read a ByteArrayNBT, but it's a ByteNBT @@ -585,12 +768,6 @@ def test_serialize_deserialize_bytearray_fail(): def test_serialize_deserialize_string_fail(): """Test serialization/deserialization of NBT STRING tag with invalid value.""" - with pytest.raises(ValueError): - StringNBT("", 0).serialize() # type:ignore - - with pytest.raises(ValueError): - StringNBT("test", "").serialize() - # Deserialization buffer = Buffer(bytearray([0x01, 0, 0])) with pytest.raises(TypeError): # Tries to read a StringNBT, but it's a ByteNBT @@ -636,7 +813,7 @@ def test_serialize_deserialize_string_fail(): ([ByteNBT(128), ByteNBT(-1)], OverflowError), # Check for error propagation ], ) -def test_serialize_list_fail(payload, error): +def test_serialize_list_fail(payload: PayloadType, error: type[Exception]): """Test serialization of NBT LIST tag with invalid value.""" with pytest.raises(error): ListNBT(payload, "test").serialize() @@ -671,10 +848,13 @@ def test_deserialize_list_fail(): ([ByteNBT(0, name="hi"), "test"], ValueError), ([ByteNBT(0, name="hi"), None], ValueError), ([ByteNBT(0), ByteNBT(-1, "Hello World")], ValueError), # All unnamed tags - ([ByteNBT(128, name="Jello"), ByteNBT(-1, name="Bonjour")], OverflowError), # Check for error propagation + ( + [ByteNBT(128, name="Jello"), ByteNBT(-1, name="Bonjour")], + OverflowError, + ), # Check for error propagation ], ) -def test_serialize_compound_fail(payload, error): +def test_serialize_compound_fail(payload: PayloadType, error: type[Exception]): """Test serialization of NBT COMPOUND tag with invalid value.""" with pytest.raises(error): CompoundNBT(payload, "test").serialize() @@ -702,30 +882,40 @@ def test_deseialize_compound_fail(): NBTag.deserialize(buffer) -def test_to_object_compound(): - """Try a few incorrect CompoundNBT.to_object() calls.""" - comp = CompoundNBT([ByteNBT(0, "test"), ByteNBT(1, "test")]) - with pytest.raises(ValueError): - comp.to_object() # Duplicate name +def test_nbtag_deserialize_compound(): + """Test deserialization of NBT COMPOUND tag from the NBTag class.""" + buf = Buffer(bytearray([0x00])) + assert NBTag.deserialize(buf, with_type=False, with_name=False) == CompoundNBT([]) - comp = CompoundNBT([ByteNBT(0), ByteNBT(1)]) - with pytest.raises(ValueError): - comp.to_object() + buf = Buffer(bytearray.fromhex("0A 00 01 61 01 00 01 62 00 00")) + assert NBTag.deserialize(buf) == CompoundNBT([ByteNBT(0, name="b")], name="a") def test_equality_compound(): """Test equality of CompoundNBT.""" - comp1 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp") - comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp") + comp1 = CompoundNBT( + [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], + "comp", + ) + comp2 = CompoundNBT( + [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], + "comp", + ) assert comp1 == comp2 comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2")], "comp") assert comp1 != comp2 - comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test4")], "comp") + comp2 = CompoundNBT( + [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test4")], + "comp", + ) assert comp1 != comp2 - comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp2") + comp2 = CompoundNBT( + [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], + "comp2", + ) assert comp1 != comp2 assert comp1 != ByteNBT(0, name="comp") @@ -744,7 +934,7 @@ def test_equality_compound(): ([0, -(1 << 31) - 1], OverflowError), ], ) -def test_serialize_intarray_fail(payload, error): +def test_serialize_intarray_fail(payload: PayloadType, error: type[Exception]): """Test serialization of NBT INTARRAY tag with invalid value.""" with pytest.raises(error): IntArrayNBT(payload, "test").serialize() @@ -781,7 +971,7 @@ def test_deserialize_intarray_fail(): ([0, -(1 << 63) - 1], OverflowError), ], ) -def test_serialize_deserialize_longarray_fail(payload, error): +def test_serialize_deserialize_longarray_fail(payload: PayloadType, error: type[Exception]): """Test serialization/deserialization of NBT LONGARRAY tag with invalid value.""" with pytest.raises(error): LongArrayNBT(payload, "test").serialize() @@ -819,60 +1009,86 @@ def test_nbt_helloworld(): buffer = Buffer(data) expected_object = { - "hello world": { - "name": "Bananrama", - } + "name": "Bananrama", } + expected_schema = {"name": StringNBT} data = CompoundNBT.deserialize(buffer) - assert data == NBTag.from_object(expected_object) + assert data == NBTag.from_object(expected_object, schema=expected_schema, name="hello world") assert data.to_object() == expected_object def test_nbt_bigfile(): """Test serialization/deserialization of a big NBT tag. - Slighly modified from the source data to also include a IntArrayNBT and a LongArrayNBT. + Slightly modified from the source data to also include a IntArrayNBT and a LongArrayNBT. Source data: https://wiki.vg/NBT#Example. """ data = "0a00054c6576656c0400086c6f6e67546573747fffffffffffffff02000973686f7274546573747fff08000a737472696e6754657374002948454c4c4f20574f524c4420544849532049532041205445535420535452494e4720c385c384c39621050009666c6f6174546573743eff1832030007696e74546573747fffffff0a00146e657374656420636f6d706f756e6420746573740a000368616d0800046e616d65000648616d70757305000576616c75653f400000000a00036567670800046e616d6500074567676265727405000576616c75653f00000000000c000f6c6973745465737420286c6f6e672900000005000000000000000b000000000000000c000000000000000d000000000000000e7fffffffffffffff0b000e6c697374546573742028696e7429000000047fffffff7ffffffe7ffffffd7ffffffc0900136c697374546573742028636f6d706f756e64290a000000020800046e616d65000f436f6d706f756e642074616720233004000a637265617465642d6f6e000001265237d58d000800046e616d65000f436f6d706f756e642074616720233104000a637265617465642d6f6e000001265237d58d0001000862797465546573747f07006562797465417272617954657374202874686520666972737420313030302076616c756573206f6620286e2a6e2a3235352b6e2a3729253130302c207374617274696e672077697468206e3d302028302c2036322c2033342c2031362c20382c202e2e2e2929000003e8003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a063005000a646f75626c65546573743efc7b5e00" # noqa: E501 - data = bytearray.fromhex(data) + data = bytes.fromhex(data) buffer = Buffer(data) - expected_object = { - "Level": { - "longTest": 9223372036854775807, - "shortTest": 32767, - "stringTest": "HELLO WORLD THIS IS A TEST STRING ÅÄÖ!", - "floatTest": 0.4982314705848694, - "intTest": 2147483647, - "nested compound test": { - "ham": {"name": "Hampus", "value": 0.75}, - "egg": {"name": "Eggbert", "value": 0.5}, + expected_object = { # Name ! Level + "longTest": 9223372036854775807, + "shortTest": 32767, + "stringTest": "HELLO WORLD THIS IS A TEST STRING ÅÄÖ!", + "floatTest": 0.4982314705848694, + "intTest": 2147483647, + "nested compound test": { + "ham": {"name": "Hampus", "value": 0.75}, + "egg": {"name": "Eggbert", "value": 0.5}, + }, + "listTest (long)": [11, 12, 13, 14, 9223372036854775807], + "listTest (int)": [2147483647, 2147483646, 2147483645, 2147483644], + "listTest (compound)": [ + {"name": "Compound tag #0", "created-on": 1264099775885}, + {"name": "Compound tag #1", "created-on": 1264099775885}, + ], + "byteTest": 127, + "byteArrayTest (the first 1000 values of (n*n*255+n*7)%100, " + "starting with n=0 (0, 62, 34, 16, 8, ...))": bytes((n * n * 255 + n * 7) % 100 for n in range(1000)), + "doubleTest": 0.4931287132182315, + } + expected_schema = { + "longTest": LongNBT, + "shortTest": ShortNBT, + "stringTest": StringNBT, + "floatTest": FloatNBT, + "intTest": IntNBT, + "nested compound test": { + "ham": { + "name": StringNBT, + "value": FloatNBT, }, - "listTest (long)": [11, 12, 13, 14, 9223372036854775807], - "listTest (int)": [2147483647, 2147483646, 2147483645, 2147483644], - "listTest (compound)": [ - {"name": "Compound tag #0", "created-on": 1264099775885}, - {"name": "Compound tag #1", "created-on": 1264099775885}, - ], - "byteTest": 127, - "byteArrayTest (the first 1000 values of (n*n*255+n*7)%100" - ", starting with n=0 (0, 62, 34, 16, 8, ...))": bytearray( - (n * n * 255 + n * 7) % 100 for n in range(1000) - ), - "doubleTest": 0.4931287132182315, - } + "egg": { + "name": StringNBT, + "value": FloatNBT, + }, + }, + "listTest (long)": LongArrayNBT, + "listTest (int)": IntArrayNBT, + "listTest (compound)": [ + { + "name": StringNBT, + "created-on": LongNBT, + } + ], + "byteTest": ByteNBT, + "byteArrayTest (the first 1000 values of (n*n*255+n*7)%100, " + "starting with n=0 (0, 62, 34, 16, 8, ...))": ByteArrayNBT, + "doubleTest": FloatNBT, } data = CompoundNBT.deserialize(buffer) # print(f"{data=}\n{expected_object=}\n{data.to_object()=}\n{NBTag.from_object(expected_object)=}") - def check_equality(self, other): + def check_equality(self: object, other: object) -> bool: """Check if two objects are equal, with deep epsilon check for floats.""" if type(self) != type(other): return False if isinstance(self, dict): + self = cast(Dict[Any, Any], self) + other = cast(Dict[Any, Any], other) if len(self) != len(other): return False for key in self: @@ -882,16 +1098,18 @@ def check_equality(self, other): return False return True if isinstance(self, list): + self = cast(List[Any], self) + other = cast(List[Any], other) if len(self) != len(other): return False return all(check_equality(self[i], other[i]) for i in range(len(self))) - if isinstance(self, float): + if isinstance(self, float) and isinstance(other, float): return abs(self - other) < 1e-6 if self != other: return False return self == other - assert data == NBTag.from_object(expected_object) + assert data == NBTag.from_object(expected_object, schema=expected_schema, name="Level") assert check_equality(data.to_object(), expected_object) @@ -902,30 +1120,16 @@ def check_equality(self, other): def test_from_object_lst_not_same_type(): """Test from_object with a list that does not have the same type.""" with pytest.raises(TypeError): - NBTag.from_object([ByteNBT(0), IntNBT(0)]) - - -def test_from_object_out_of_bounds(): - """Test from_object with a value that is out of bounds.""" - with pytest.raises(ValueError): - NBTag.from_object({"test": 1 << 63}) - - with pytest.raises(ValueError): - NBTag.from_object({"test": -(1 << 63) - 1}) - - with pytest.raises(ValueError): - NBTag.from_object({"test": [1 << 63]}) - - with pytest.raises(ValueError): - NBTag.from_object({"test": [-(1 << 63) - 1]}) + NBTag.from_object([0, "test"], [IntNBT, StringNBT]) def test_from_object_morecases(): """Test from_object with more edge cases.""" - class CustomType: - def __bytes__(self): - return b"test" + class CustomType(NBTagConvertible): + @override + def to_nbt(self, name: str = "") -> NBTag: + return ByteArrayNBT(b"CustomType", name) assert NBTag.from_object( { @@ -933,60 +1137,48 @@ def __bytes__(self): "bytearray": b"test", # Conversion from bytes "empty_list": [], # Empty list with type EndNBT "empty_compound": {}, # Empty compound - "end_NBTag": None, # Should not be done in practice, would create a broken buffer if serialized - "custom": CustomType(), # Custom type with __bytes__ method - } + "custom": CustomType(), # Custom type with to_nbt method + "recursive_list": [ + [0, 1, 2], + [3, 4, 5], + ], + }, + { + "nbtag": ByteNBT, + "bytearray": ByteArrayNBT, + "empty_list": [], + "empty_compound": {}, + "custom": CustomType, + "recursive_list": [[IntNBT], [ShortNBT]], + }, ) == CompoundNBT( [ # Order is shuffled because the spec does not require a specific order CompoundNBT([], "empty_compound"), ByteArrayNBT(b"test", "bytearray"), - ByteArrayNBT(b"test", "custom"), + ByteArrayNBT(b"CustomType", "custom"), ListNBT([], "empty_list"), ByteNBT(0, "nbtag"), - EndNBT(), + ListNBT( + [ListNBT([IntNBT(0), IntNBT(1), IntNBT(2)]), ListNBT([ShortNBT(3), ShortNBT(4), ShortNBT(5)])], + "recursive_list", + ), ] ) - # Not a valid object - with pytest.raises(TypeError): - NBTag.from_object({"test": object()}) - compound = CompoundNBT.from_object( { "test": ByteNBT(0), "test2": IntNBT(0), }, + { + "test": ByteNBT, + "test2": IntNBT, + }, name="compound", ) - assert compound["test"] == ByteNBT(0, "test") - assert compound["test2"] == IntNBT(0, "test2") - with pytest.raises(KeyError): - compound["test3"] - # Cannot index into a ByteNBT - with pytest.raises(TypeError): - compound["test"][0] # type:ignore - - listnbt = ListNBT.from_object([0, 1, 2], use_int_array=False) - assert listnbt[0] == ByteNBT(0) - assert listnbt[1] == ByteNBT(1) - assert listnbt[2] == ByteNBT(2) - with pytest.raises(IndexError): - listnbt[3] - with pytest.raises(TypeError): - listnbt["hello"] - - assert listnbt[-1] == ByteNBT(2) - assert listnbt[-2] == ByteNBT(1) - assert listnbt[-3] == ByteNBT(0) - - with pytest.raises(TypeError): - listnbt[object()] # type:ignore - - assert listnbt.value == [0, 1, 2] - assert listnbt.to_object() == [0, 1, 2] assert ListNBT([]).value == [] - assert compound.to_object() == {"compound": {"test": 0, "test2": 0}} + assert compound.to_object(include_name=True) == {"compound": {"test": 0, "test2": 0}} assert compound.value == {"test": 0, "test2": 0} assert ListNBT([IntNBT(0)]).value == [0] @@ -1001,21 +1193,86 @@ def __bytes__(self): assert IntArrayNBT([0, 1, 2]).value == [0, 1, 2] assert LongArrayNBT([0, 1, 2, 3]).value == [0, 1, 2, 3] - invalid = ListNBT("Hello", "name") - with pytest.raises(AttributeError): - invalid[0] - invalid = CompoundNBT([ByteNBT(0, "Byte"), "Hi"], "name") - with pytest.raises(AttributeError): - invalid["Byte"] # Attribute error is raised when the structure is incorrectly constructed +@pytest.mark.parametrize( + ("data", "schema", "error", "error_msg"), + [ + # Data is not a list + ({"test": 0}, {"test": [ByteNBT]}, TypeError, "Expected a list, but found a different type."), + # Expected a list of dict, got a list of NBTags for schema + ( + {"test": [1, 0]}, + {"test": [ByteNBT, IntNBT]}, + TypeError, + "Expected a list of lists or dictionaries, but found a different type.", + ), + # Schema and data have different lengths + ( + [[1], [2], [3]], + [[ByteNBT], [IntNBT]], + ValueError, + "The schema and the data must have the same length.", + ), + # schema empty, data is not + ([1], [], ValueError, "The schema is empty, but the data is not."), + # Schema is a dict, data is not + (["test"], {"test": ByteNBT}, TypeError, "Expected a dictionary, but found a different type."), + # Schema is not a dict, list or subclass of NBTagConvertible + ( + ["test"], + "test", + TypeError, + "The schema must be a list, dict or a subclass of either NBTag or NBTagConvertible.", + ), + # Schema contains CompoundNBT or ListNBT instead of a dict or list + ( + {"test": 0}, + CompoundNBT, + ValueError, + "The schema must specify the type of the elements in CompoundNBT and ListNBT tags.", + ), + ( + ["test"], + ListNBT, + ValueError, + "The schema must specify the type of the elements in CompoundNBT and ListNBT tags.", + ), + # The schema specifies a type, but the data is a dict with more than one key + ( + {"test": 0, "test2": 1}, + ByteNBT, + ValueError, + "Expected a dictionary with a single key-value pair.", + ), + # The data is not of the right type to be a payload + ( + {"test": object()}, + ByteNBT, + TypeError, + "Expected a bytes, str, int, float, but found object.", + ), + # The data is a list but not all elements are ints + ( + [0, "test"], + IntArrayNBT, + TypeError, + "Expected a list of integers.", + ), + ], +) +def test_from_object_error(data: Any, schema: Any, error: type[Exception], error_msg: str): + """Test from_object with erroneous data.""" + with pytest.raises(error, match=error_msg): + NBTag.from_object(data, schema) def test_to_object_morecases(): """Test to_object with more edge cases.""" - class CustomType: - def __bytes__(self): - return b"test" + class CustomType(NBTagConvertible): + @override + def to_nbt(self, name: str = "") -> NBTag: + return ByteArrayNBT(b"CustomType", name) assert NBTag.from_object( { @@ -1023,27 +1280,58 @@ def __bytes__(self): "empty_list": [], "empty_compound": {}, "custom": CustomType(), - } - ).to_object() == { - "bytearray": b"test", - "empty_list": [], - "empty_compound": {}, - "custom": b"test", - } - - assert NBTag.to_object(CompoundNBT([])) == {} + "recursive_list": [ + [0, 1, 2], + [3, 4, 5], + ], + "compound_list": [{"test": 0, "test2": 1}, {"test2": 1}], + }, + { + "bytearray": ByteArrayNBT, + "empty_list": [], + "empty_compound": {}, + "custom": CustomType, + "recursive_list": [[IntNBT], [ShortNBT]], + "compound_list": [{"test": ByteNBT, "test2": IntNBT}, {"test2": IntNBT}], + }, + ).to_object(include_schema=True) == ( + { + "bytearray": b"test", + "empty_list": [], + "empty_compound": {}, + "custom": b"CustomType", + "recursive_list": [[0, 1, 2], [3, 4, 5]], + "compound_list": [{"test": 0, "test2": 1}, {"test2": 1}], + }, + { + "bytearray": ByteArrayNBT, + "empty_list": [], + "empty_compound": {}, + "custom": ByteArrayNBT, # After the conversion, the NBT tag is a ByteArrayNBT + "recursive_list": [[IntNBT], [ShortNBT]], + "compound_list": [{"test": ByteNBT, "test2": IntNBT}, {"test2": IntNBT}], + }, + ) - assert EndNBT().to_object() == {} # Does not add anything when doing dict.update assert FloatNBT(0.5).to_object() == 0.5 - assert FloatNBT(0.5, "Hello World").to_object() == {"Hello World": 0.5} + assert FloatNBT(0.5, "Hello World").to_object(include_name=True) == {"Hello World": 0.5} assert ByteArrayNBT(b"test").to_object() == b"test" # Do not add name when there is no name assert StringNBT("test").to_object() == "test" - assert StringNBT("test", "name").to_object() == {"name": "test"} + assert StringNBT("test", "name").to_object(include_name=True) == {"name": "test"} assert ListNBT([ByteNBT(0), ByteNBT(1)]).to_object() == [0, 1] - assert ListNBT([ByteNBT(0), ByteNBT(1)], "name").to_object() == {"name": [0, 1]} + assert ListNBT([ByteNBT(0), ByteNBT(1)], "name").to_object(include_name=True) == {"name": [0, 1]} assert IntArrayNBT([0, 1, 2]).to_object() == [0, 1, 2] assert LongArrayNBT([0, 1, 2]).to_object() == [0, 1, 2] + with pytest.raises(TypeError): + NBTag.to_object(CompoundNBT([])) + + with pytest.raises(TypeError): + ListNBT([CompoundNBT([]), ListNBT([])]).to_object(include_schema=True) + + with pytest.raises(TypeError): + ListNBT([IntNBT(0), ShortNBT(0)]).to_object(include_schema=True) + def test_data_conversions(): """Test data conversions using the built-in functions.""" @@ -1063,11 +1351,32 @@ def test_data_conversions(): def test_init_nbtag_directly(): """Test initializing NBTag directly.""" with pytest.raises(TypeError): - NBTag(0) - with pytest.raises(TypeError): - NBTag(0, "test") + NBTag(0) # type: ignore # I know, that's what I'm testing + + +@pytest.mark.parametrize( + ("buffer_content", "tag_type"), + [ + ("01", EndNBT), + ("00 00", ByteNBT), + ("01 0000", ShortNBT), + ("02 00000000", IntNBT), + ("03 0000000000000000", LongNBT), + ("04 3F800000", FloatNBT), + ("05 3FF999999999999A", DoubleNBT), + ("06 00", ByteArrayNBT), + ("07 00", StringNBT), + ("08 00", ListNBT), + ("09 00", CompoundNBT), + ("0A 00", IntArrayNBT), + ("0B 00", LongArrayNBT), + ], +) +def test_wrong_type(buffer_content: str, tag_type: type[NBTag]): + """Test read_from with wrong tag type in the buffer.""" + buffer = Buffer(bytearray.fromhex(buffer_content)) with pytest.raises(TypeError): - NBTag(0, name="test") + tag_type.read_from(buffer, with_name=False) # endregion From 63093023eff5c837b5e12391f85731234c6711dc Mon Sep 17 00:00:00 2001 From: Alexis Rossfelder Date: Tue, 30 Apr 2024 20:06:30 +0200 Subject: [PATCH 3/3] Rewrite the `NBTag.from_object` method Remove useless docstrings with @override Fix formatting for Sphinx docstrings Return NotImplemented instead of raising the exception Make use of StructFormat to read/write the numeric types --- mcproto/types/nbt.py | 606 ++++++++++---------------------- tests/mcproto/types/test_nbt.py | 88 +++-- 2 files changed, 228 insertions(+), 466 deletions(-) diff --git a/mcproto/types/nbt.py b/mcproto/types/nbt.py index 8932b5cf..cc529a86 100644 --- a/mcproto/types/nbt.py +++ b/mcproto/types/nbt.py @@ -32,100 +32,99 @@ """ Implementation of the NBT (Named Binary Tag) format used in Minecraft as described in the NBT specification -(:seealso: :class:`NBTagType`). -""" -# region NBT Specification +Source : `Minecraft NBT Spec `_ -class NBTagType(IntEnum): - """Enumeration of the different types of NBT tags. +Named Binary Tag specification - Source : https://web.archive.org/web/20110723210920/http://www.minecraft.net/docs/NBT.txt +NBT (Named Binary Tag) is a tag based binary format designed to carry large amounts of binary data with smaller +amounts of additional data. +An NBT file consists of a single GZIPped Named Tag of type TAG_Compound. - Named Binary Tag specification +A Named Tag has the following format: - NBT (Named Binary Tag) is a tag based binary format designed to carry large amounts of binary data with smaller - amounts of additional data. - An NBT file consists of a single GZIPped Named Tag of type TAG_Compound. + byte tagType + TAG_String name + [payload] - A Named Tag has the following format: +The tagType is a single byte defining the contents of the payload of the tag. - byte tagType - TAG_String name - [payload] +The name is a descriptive name, and can be anything (eg "cat", "banana", "Hello World!"). It has nothing to do with +the tagType. +The purpose for this name is to name tags so parsing is easier and can be made to only look for certain recognized +tag names. +Exception: If tagType is TAG_End, the name is skipped and assumed to be "". - The tagType is a single byte defining the contents of the payload of the tag. +The [payload] varies by tagType. - The name is a descriptive name, and can be anything (eg "cat", "banana", "Hello World!"). It has nothing to do with - the tagType. - The purpose for this name is to name tags so parsing is easier and can be made to only look for certain recognized - tag names. - Exception: If tagType is TAG_End, the name is skipped and assumed to be "". +Note that ONLY Named Tags carry the name and tagType data. Explicitly identified Tags (such as TAG_String above) +only contains the payload. - The [payload] varies by tagType. +.. seealso:: :class:`NBTagType` - Note that ONLY Named Tags carry the name and tagType data. Explicitly identified Tags (such as TAG_String above) - only contains the payload. +""" +# region NBT Specification - The tag types and respective payloads are: +class NBTagType(IntEnum): + """Enumeration of the different types of NBT tags. - TYPE: 0 NAME: TAG_End - Payload: None. - Note: This tag is used to mark the end of a list. - Cannot be named! If type 0 appears where a Named Tag is expected, the name is assumed to be "". - (In other words, this Tag is always just a single 0 byte when named, and nothing in all other cases) + TYPE: 0 NAME: TAG_End + Payload: None. + Note: This tag is used to mark the end of a list. + Cannot be named! If type 0 appears where a Named Tag is expected, the name is assumed to be "". + (In other words, this Tag is always just a single 0 byte when named, and nothing in all other cases) - TYPE: 1 NAME: TAG_Byte - Payload: A single signed byte (8 bits) + TYPE: 1 NAME: TAG_Byte + Payload: A single signed byte (8 bits) - TYPE: 2 NAME: TAG_Short - Payload: A signed short (16 bits, big endian) + TYPE: 2 NAME: TAG_Short + Payload: A signed short (16 bits, big endian) - TYPE: 3 NAME: TAG_Int - Payload: A signed short (32 bits, big endian) + TYPE: 3 NAME: TAG_Int + Payload: A signed short (32 bits, big endian) - TYPE: 4 NAME: TAG_Long - Payload: A signed long (64 bits, big endian) + TYPE: 4 NAME: TAG_Long + Payload: A signed long (64 bits, big endian) - TYPE: 5 NAME: TAG_Float - Payload: A floating point value (32 bits, big endian, IEEE 754-2008, binary32) + TYPE: 5 NAME: TAG_Float + Payload: A floating point value (32 bits, big endian, IEEE 754-2008, binary32) - TYPE: 6 NAME: TAG_Double - Payload: A floating point value (64 bits, big endian, IEEE 754-2008, binary64) + TYPE: 6 NAME: TAG_Double + Payload: A floating point value (64 bits, big endian, IEEE 754-2008, binary64) - TYPE: 7 NAME: TAG_Byte_Array - Payload: TAG_Int length - An array of bytes of unspecified format. The length of this array is bytes + TYPE: 7 NAME: TAG_Byte_Array + Payload: TAG_Int length + An array of bytes of unspecified format. The length of this array is bytes - TYPE: 8 NAME: TAG_String - Payload: TAG_Short length - An array of bytes defining a string in UTF-8 format. The length of this array is bytes + TYPE: 8 NAME: TAG_String + Payload: TAG_Short length + An array of bytes defining a string in UTF-8 format. The length of this array is bytes - TYPE: 9 NAME: TAG_List - Payload: TAG_Byte tagId - TAG_Int length - A sequential list of Tags (not Named Tags), of type . The length of this array is - Tags - Notes: All tags share the same type. + TYPE: 9 NAME: TAG_List + Payload: TAG_Byte tagId + TAG_Int length + A sequential list of Tags (not Named Tags), of type . The length of this array is + Tags + Notes: All tags share the same type. - TYPE: 10 NAME: TAG_Compound - Payload: A sequential list of Named Tags. This array keeps going until a TAG_End is found. - TAG_End end - Notes: If there's a nested TAG_Compound within this tag, that one will also have a TAG_End, so simply reading - until the next TAG_End will not work. - The names of the named tags have to be unique within each TAG_Compound - The order of the tags is not guaranteed. + TYPE: 10 NAME: TAG_Compound + Payload: A sequential list of Named Tags. This array keeps going until a TAG_End is found. + TAG_End end + Notes: If there's a nested TAG_Compound within this tag, that one will also have a TAG_End, so simply reading + until the next TAG_End will not work. + The names of the named tags have to be unique within each TAG_Compound + The order of the tags is not guaranteed. - // NEW TAGS - TYPE: 11 NAME: TAG_Int_Array - Payload: TAG_Int length - An array of integers. The length of this array is integers + // NEW TAGS (not in the original spec) + TYPE: 11 NAME: TAG_Int_Array + Payload: TAG_Int length + An array of integers. The length of this array is integers - TYPE: 12 NAME: TAG_Long_Array - Payload: TAG_Int length - An array of longs. The length of this array is longs + TYPE: 12 NAME: TAG_Long_Array + Payload: TAG_Int length + An array of longs. The length of this array is longs """ @@ -209,8 +208,8 @@ def serialize(self, with_type: bool = True, with_name: bool = True) -> Buffer: :param with_type: Whether to include the type of the tag in the serialization. :param with_name: Whether to include the name of the tag in the serialization. + .. note:: These parameters only control the first level of serialization. - :note: These parameters only control the first level of serialization. :return: The buffer containing the serialized NBT tag. """ buf = Buffer() @@ -218,6 +217,12 @@ def serialize(self, with_type: bool = True, with_name: bool = True) -> Buffer: return buf def _write_header(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the header of the NBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + """ if with_type: tag_type = _get_tag_type(self) buf.write_value(StructFormat.BYTE, tag_type.value) @@ -267,8 +272,8 @@ def _read_header(cls, buf: Buffer, read_type: bool = True, with_name: bool = Tru :return: A tuple containing the name and the tag type. - :note: It is possible that this function reads nothing from the buffer if both with_name and read_type are set - to False. + .. note:: It is possible that this function reads nothing from the buffer if both with_name and read_type are + set to False. """ if read_type: try: @@ -306,87 +311,112 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) @staticmethod def from_object(data: FromObjectType, schema: FromObjectSchema, name: str = "") -> NBTag: - """Create an NBT tag from a dictionary. + """Create an NBT tag from a python object and a schema. - :param data: The dictionary to create the NBT tag from. + :param data: The python object to create the NBT tag from. :param schema: The schema used to create the NBT tags. + :param name: The name of the NBT tag. - If the schema is a list, the data must be a list and the schema must either contain a single element - representing the type of the elements in the list or multiple dictionaries or lists representing the types - of the elements in the list since they are the only types that have a variable type. - - Example: - ```python - schema = [IntNBT] - data = [1, 2, 3] - schema = [[IntNBT], [StringNBT]] - data = [[1, 2, 3], ["a", "b", "c"]] - ``` + The schema is a description of the types of the data in the python object. + The schema can be a subclass of NBTag (e.g. IntNBT, StringNBT, CompoundNBT, etc.), a dictionary, a list, a + tuple, or an object that has a `to_nbt` method. - If the schema is a dictionary, the data must be a dictionary and the schema must contain the keys and the - types of the values in the dictionary. + Example of schema: + schema = { + "string": StringNBT, + "list_of_floats": [FloatNBT], + "list_of_compounds": [{ + "key": StringNBT, + "value": IntNBT, + }], + "list_of_lists": [[IntNBT], [StringNBT]], + } - Example: - ```python - schema = {"key": IntNBT} - data = {"key": 1} - ``` + This would be translated into a CompoundNBT - If the schema is a subclass of NBTag, the data will be passed to the constructor of the schema. - If the schema is not a list, dictionary or subclass of NBTag, the data will be converted to an NBT tag - using the `to_nbt` method of the data. + :return: The NBT tag created from the python object. + """ + # Case 0 : schema is an object with a `to_nbt` method (could be a subclass of NBTag for all we know, as long + # as the data is an instance of the schema it will work) + if isinstance(schema, type) and hasattr(schema, "to_nbt") and isinstance(data, schema): + return data.to_nbt(name=name) - :param name: The name of the NBT tag. + # Case 1 : schema is a NBTag subclass + if isinstance(schema, type) and issubclass(schema, NBTag): + if schema in (CompoundNBT, ListNBT): + raise ValueError("Use a list or a dictionary in the schema to create a CompoundNBT or a ListNBT.") + # Check if the data contains the name (if it is a dictionary) + if isinstance(data, dict): + if len(data) != 1: + raise ValueError("Expected a dictionary with a single key-value pair.") + # We also check if the name isn't already set + if name: + raise ValueError("The name is already set.") + key, value = next(iter(data.items())) + # Recursive call to go to the next part + return NBTag.from_object(value, schema, name=key) + # Else we check if the data can be a payload for the tag + if not isinstance(data, (bytes, str, int, float, list)): + raise TypeError(f"Expected a bytes, str, int, float, but found {type(data).__name__}.") + # Check if the data is a list of integers + if isinstance(data, list) and not all(isinstance(item, int) for item in data): + raise TypeError("Expected a list of integers.") + data = cast(Union[bytes, str, int, float, List[int]], data) + # Create the tag with the data and the name + return schema(data, name=name) + + # Sanity check : Verify that all type schemas have been handled + if not isinstance(schema, (list, tuple, dict)): + raise TypeError( + "The schema must be a list, dict, a subclass of NBTag or an object with a `to_nbt` method." + ) - :return: The NBT tag created from the dictionary. - """ - if isinstance(schema, (list, tuple)): - if not isinstance(data, list): - raise TypeError("Expected a list, but found a different type.") - payload: list[NBTag] = [] - if len(schema) > 1: - if not all(isinstance(item, (list, dict)) for item in schema): - raise TypeError("Expected a list of lists or dictionaries, but found a different type.") - if len(schema) != len(data): - raise ValueError("The schema and the data must have the same length.") - for item, sub_schema in zip(data, schema): - payload.append(NBTag.from_object(item, sub_schema)) - else: - if len(schema) == 0 and len(data) > 0: - raise ValueError("The schema is empty, but the data is not.") - if len(schema) == 0: - return ListNBT([], name=name) - - schema = schema[0] - for item in data: - payload.append(NBTag.from_object(item, schema)) - return ListNBT(payload, name=name) + # Case 2 : schema is a dictionary if isinstance(schema, dict): + # We can unpack the dictionary and create a CompoundNBT tag if not isinstance(data, dict): raise TypeError("Expected a dictionary, but found a different type.") + # Iterate over the dictionary payload: list[NBTag] = [] for key, value in data.items(): + # Recursive calls payload.append(NBTag.from_object(value, schema[key], name=key)) + # Finally we assign the payload and the name to the CompoundNBT tag return CompoundNBT(payload, name=name) - if not isinstance(schema, type) or not issubclass(schema, (NBTag, NBTagConvertible)): # type: ignore - raise TypeError("The schema must be a list, dict or a subclass of either NBTag or NBTagConvertible.") - if isinstance(data, schema): - return data.to_nbt(name=name) - schema = cast(Type[NBTag], schema) # Last option - if issubclass(schema, (CompoundNBT, ListNBT)): - raise ValueError("The schema must specify the type of the elements in CompoundNBT and ListNBT tags.") - if isinstance(data, dict): - if len(data) != 1: - raise ValueError("Expected a dictionary with a single key-value pair.") - key, value = next(iter(data.items())) - return schema.from_object(value, schema, name=key) - if not isinstance(data, (bytes, str, int, float, list)): - raise TypeError(f"Expected a bytes, str, int, float, but found {type(data).__name__}.") - if isinstance(data, list) and not all(isinstance(item, int) for item in data): - raise TypeError("Expected a list of integers.") # LongArrayNBT, IntArrayNBT - - data = cast(Union[bytes, str, int, float, List[int]], data) - return schema(data, name=name) + + # Case 3 : schema is a list or a tuple + # We need to check if every element in the schema has the same type + # but keep in mind that dict and list are also valid types, as long + # as there are only dicts, or only lists in the schema + if not isinstance(data, list): + raise TypeError("Expected a list, but found a different type.") + payload: list[NBTag] = [] + if len(schema) == 1: + # We have two cases here, either the schema supports an unknown number of elements of a single type ... + children_schema = schema[0] + for item in data: + # No name in list items + payload.append(NBTag.from_object(item, children_schema)) + return ListNBT(payload, name=name) + + # ... or the schema is a list of schemas + # Check if the schema and the data have the same length + if len(schema) != len(data): + raise ValueError(f"The schema and the data must have the same length. ({len(schema)=} != {len(data)=})") + if len(schema) == 0: + return ListNBT([], name=name) + # Check that the schema only has one type of elements + first_schema = schema[0] + # Dict/List case + if isinstance(first_schema, (list, dict)) and not all(isinstance(item, type(first_schema)) for item in schema): + raise TypeError(f"Expected a list of lists or dictionaries, but found a different type ({schema=}).") + # NBTag case + if isinstance(first_schema, type) and not all(item == first_schema for item in schema): + raise TypeError(f"The schema must contain a single type of elements. ({schema=})") + + for item, sub_schema in zip(data, schema): + payload.append(NBTag.from_object(item, sub_schema)) + return ListNBT(payload, name=name) def to_object( self, include_schema: bool = False, include_name: bool = False @@ -403,7 +433,7 @@ def to_object( - A tuple which includes one of the above and a schema describing the types of the original tag. """ if type(self) is EndNBT: - raise NotImplementedError("Cannot convert an EndNBT tag to a python object.") + return NotImplemented if type(self) in (CompoundNBT, ListNBT): raise TypeError( f"Use the `{type(self).__name__}.to_object()` method to convert the tag to a python object." @@ -423,7 +453,7 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: """Check equality between two NBT tags.""" if not isinstance(other, NBTag): - raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") + return NotImplemented if type(self) is not type(other): return False return self.name == other.name and self.payload == other.payload @@ -432,7 +462,7 @@ def __eq__(self, other: object) -> bool: def to_nbt(self, name: str = "") -> NBTag: """Convert the object to an NBT tag. - ..warning This is already an NBT tag, so it will modify the name of the tag and return itself. + .. warning:: This is already an NBT tag, so it will modify the name of the tag and return itself. """ self.name = name return self @@ -459,26 +489,11 @@ def __init__(self): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = False) -> None: - """Write the EndNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. - """ self._write_header(buf, with_type=with_type, with_name=False) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> EndNBT: - """Read the EndNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag. Has no effect on the EndNBT tag. - - :return: The EndNBT tag. - """ _, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") @@ -488,19 +503,11 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) def to_object( self, include_schema: bool = False, include_name: bool = False ) -> PayloadType | Mapping[str, PayloadType]: - """Convert the EndNBT tag to a python object. - - :param include_schema: Whether to return a schema describing the types of the original tag. - :param include_name: Whether to include the name of the tag in the output. - - :return: None - """ return NotImplemented @property @override def value(self) -> PayloadType: - """Get the payload of the EndNBT tag in a python-friendly format.""" return NotImplemented @@ -512,12 +519,6 @@ class ByteNBT(NBTag): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the ByteNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. - """ self._write_header(buf, with_type=with_type, with_name=with_name) if self.payload < -(1 << 7) or self.payload >= 1 << 7: raise OverflowError("Byte value out of range.") @@ -527,15 +528,6 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteNBT: - """Read the ByteNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. - - :return: The ByteNBT tag. - """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") @@ -552,7 +544,6 @@ def __int__(self) -> int: @property @override def value(self) -> int: - """Get the integer value of the IntNBT tag.""" return self.payload @@ -563,33 +554,16 @@ class ShortNBT(ByteNBT): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the ShortNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. - - :note: The short value is written as a signed 16-bit integer in big-endian format. - """ self._write_header(buf, with_type=with_type, with_name=with_name) if self.payload < -(1 << 15) or self.payload >= 1 << 15: raise OverflowError("Short value out of range.") - buf.write(self.payload.to_bytes(2, "big", signed=True)) + buf.write_value(StructFormat.SHORT, self.payload) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ShortNBT: - """Read the ShortNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. - - :return: The ShortNBT tag. - """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") @@ -597,7 +571,7 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) if buf.remaining < 2: raise IOError("Buffer does not contain enough data to read a short.") - return ShortNBT(int.from_bytes(buf.read(2), "big", signed=True), name=name) + return ShortNBT(buf.read_value(StructFormat.SHORT), name=name) class IntNBT(ByteNBT): @@ -607,34 +581,17 @@ class IntNBT(ByteNBT): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the IntNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. - - :note: The integer value is written as a signed 32-bit integer in big-endian format. - """ self._write_header(buf, with_type=with_type, with_name=with_name) if self.payload < -(1 << 31) or self.payload >= 1 << 31: raise OverflowError("Integer value out of range.") # No more messing around with the struct, we want 32 bits of data no matter what - buf.write(self.payload.to_bytes(4, "big", signed=True)) + buf.write_value(StructFormat.INT, self.payload) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntNBT: - """Read the IntNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. - - :return: The IntNBT tag. - """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") @@ -642,7 +599,7 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) if buf.remaining < 4: raise IOError("Buffer does not contain enough data to read an int.") - return IntNBT(int.from_bytes(buf.read(4), "big", signed=True), name=name) + return IntNBT(buf.read_value(StructFormat.INT), name=name) class LongNBT(ByteNBT): @@ -652,34 +609,17 @@ class LongNBT(ByteNBT): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the LongNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. - - :note: The long value is written as a signed 64-bit integer in big-endian format. - """ self._write_header(buf, with_type=with_type, with_name=with_name) if self.payload < -(1 << 63) or self.payload >= 1 << 63: raise OverflowError("Long value out of range.") # No more messing around with the struct, we want 64 bits of data no matter what - buf.write(self.payload.to_bytes(8, "big", signed=True)) + buf.write_value(StructFormat.LONGLONG, self.payload) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongNBT: - """Read the LongNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. - - :return: The LongNBT tag. - """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") @@ -687,8 +627,7 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) if buf.remaining < 8: raise IOError("Buffer does not contain enough data to read a long.") - payload = int.from_bytes(buf.read(8), "big", signed=True) - return LongNBT(payload, name=name) + return LongNBT(buf.read_value(StructFormat.LONGLONG), name=name) class FloatNBT(NBTag): @@ -700,29 +639,12 @@ class FloatNBT(NBTag): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the FloatNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. - - :note: The float value is written as a 32-bit floating-point value in big-endian format. - """ self._write_header(buf, with_type=with_type, with_name=with_name) buf.write_value(StructFormat.FLOAT, self.payload) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> FloatNBT: - """Read the FloatNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. - - :return: The FloatNBT tag. - """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") @@ -736,30 +658,9 @@ def __float__(self) -> float: """Get the float value of the FloatNBT tag.""" return self.payload - @override - def __eq__(self, other: object) -> bool: - """Check equality between two FloatNBT tags. - - :param other: The other FloatNBT tag to compare to. - - :return: True if the tags are equal, False otherwise. - - :note: The float values are compared with a small epsilon (1e-6) to account for floating-point errors. - """ - if not isinstance(other, NBTag): - raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") - # Compare the float values with a small epsilon - if type(self) is not type(other): - return False - other.payload = cast(float, other.payload) - if self.name != other.name: - return False - return abs(self.payload - other.payload) < 1e-6 - @property @override def value(self) -> float: - """Get the float value of the FloatNBT tag.""" return self.payload @@ -770,29 +671,12 @@ class DoubleNBT(FloatNBT): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the DoubleNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. - - :note: The double value is written as a 64-bit floating-point value in big-endian format. - """ self._write_header(buf, with_type=with_type, with_name=with_name) buf.write_value(StructFormat.DOUBLE, self.payload) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> DoubleNBT: - """Read the DoubleNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. - - :return: The DoubleNBT tag. - """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") @@ -812,14 +696,6 @@ class ByteArrayNBT(NBTag): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the ByteArrayNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. - - :note: The length of the byte array is written as a signed 32-bit integer in big-endian format. - """ self._write_header(buf, with_type=with_type, with_name=with_name) IntNBT(len(self.payload)).write_to(buf, with_type=False, with_name=False) buf.write(self.payload) @@ -827,15 +703,6 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteArrayNBT: - """Read the ByteArrayNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. - - :return: The ByteArrayNBT tag. - """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") @@ -860,7 +727,6 @@ def __bytes__(self) -> bytes: @override def __repr__(self) -> str: - """Get a string representation of the ByteArrayNBT tag.""" if self.name: return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)})" if len(self.payload) < 8: @@ -870,7 +736,6 @@ def __repr__(self) -> str: @property @override def value(self) -> bytes: - """Get the bytes value of the ByteArrayNBT tag.""" return self.payload @@ -883,14 +748,6 @@ class StringNBT(NBTag): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the StringNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. - - :note: The length of the string is written as a signed 16-bit integer in big-endian format. - """ self._write_header(buf, with_type=with_type, with_name=with_name) if len(self.payload) > 32767: # Check the length of the string (can't generate strings that long in tests) @@ -903,15 +760,6 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> StringNBT: - """Read the StringNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. - - :return: The StringNBT tag. - """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") @@ -933,13 +781,11 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) @override def __str__(self) -> str: - """Get the string value of the StringNBT tag.""" return self.payload @property @override def value(self) -> str: - """Get the string value of the StringNBT tag.""" return self.payload @@ -952,15 +798,6 @@ class ListNBT(NBTag): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the ListNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. - - :note: The tag type of the list is written as a single byte, followed by the length of the list as a signed - 32-bit integer in big-endian format. The tags in the list are then serialized one by one. - """ self._write_header(buf, with_type=with_type, with_name=with_name) if not self.payload: @@ -989,15 +826,6 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ListNBT: - """Read the ListNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. - - :return: The ListNBT tag. - """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") @@ -1030,7 +858,6 @@ def __iter__(self) -> Iterator[NBTag]: @override def __repr__(self) -> str: - """Get a string representation of the ListNBT tag.""" if self.name: return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" if len(self.payload) < 8: @@ -1045,17 +872,6 @@ def to_object( | Mapping[str, list[PayloadType]] | tuple[list[PayloadType] | Mapping[str, list[PayloadType]], list[FromObjectSchema]] ): - """Convert the ListNBT tag to a python object. - - :param include_schema: Whether to return a schema describing the types of the original tag. - :param include_name: Whether to include the name of the tag in the output. - If the tag has no name, the name will be set to "". - - :return: Either : - - A list containing the payload of the tag. (default) - - A dictionary containing the name associated with a list containing the payload of the tag. - - A tuple which includes one of the above and a list of schemas describing the types of the original tag. - """ result = [tag.to_object() for tag in self.payload] result = cast(List[PayloadType], result) result = result if not include_name else {self.name: result} @@ -1085,7 +901,6 @@ def to_object( @property @override def value(self) -> list[PayloadType]: - """Get the payload of the ListNBT tag in a python-friendly format.""" return [tag.value for tag in self.payload] @@ -1098,15 +913,6 @@ class CompoundNBT(NBTag): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the CompoundNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. THis only affects the name of - the compound tag itself, not the names of the tags inside the compound. - - :note: The tags in the compound are serialized one by one, followed by an EndNBT tag. - """ self._write_header(buf, with_type=with_type, with_name=with_name) if not self.payload: EndNBT().write_to(buf, with_name=False, with_type=True) @@ -1130,15 +936,6 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> CompoundNBT: - """Read the CompoundNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. - - :return: The CompoundNBT tag. - """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if _get_tag_type(cls) != tag_type: raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") @@ -1161,7 +958,6 @@ def __iter__(self): @override def __repr__(self) -> str: - """Get a string representation of the CompoundNBT tag.""" if self.name: return f"{type(self).__name__}[{self.name!r}]({dict(self)})" return f"{type(self).__name__}({dict(self)})" @@ -1177,17 +973,6 @@ def to_object( Mapping[str, FromObjectSchema], ] ): - """Convert the CompoundNBT tag to a python object. - - :param include_schema: Whether to return a schema describing the types of the original tag and its children. - :param include_name: Whether to include the name of the tag in the output. - If the tag has no name, the name will be set to "". - - :return: Either : - - A dictionary containing the payload of the tag. (default) - - A dictionary containing the name associated with a dictionary containing the payload of the tag. - - A tuple which includes one of the above and a dictionary of schemas describing the types of the original tag. - """ result = {tag.name: tag.to_object() for tag in self.payload} result = cast(Mapping[str, PayloadType], result) result = result if not include_name else {self.name: result} @@ -1210,12 +995,12 @@ def __eq__(self, other: object) -> bool: :return: True if the tags are equal, False otherwise. - :note: The order of the tags is not guaranteed, but the names of the tags must match. This function assumes + .. note:: The order of the tags is not guaranteed, but the names of the tags must match. This function assumes that there are no duplicate tags in the compound. """ # The order of the tags is not guaranteed if not isinstance(other, NBTag): - raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") + return NotImplemented if type(self) is not type(other): return False if self.name != other.name: @@ -1228,7 +1013,6 @@ def __eq__(self, other: object) -> bool: @property @override def value(self) -> dict[str, PayloadType]: - """Get the dictionary of tags in the CompoundNBT tag.""" return {tag.name: tag.value for tag in self.payload} @@ -1241,14 +1025,6 @@ class IntArrayNBT(NBTag): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the IntArrayNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. - - :note: The length of the integer array is written as a signed 32-bit integer in big-endian format. - """ self._write_header(buf, with_type=with_type, with_name=with_name) if any(not isinstance(item, int) for item in self.payload): # type: ignore # We want to check anyway @@ -1264,15 +1040,6 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntArrayNBT: - """Read the IntArrayNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. - - :return: The IntArrayNBT tag. - """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if tag_type != NBTagType.INT_ARRAY: raise TypeError(f"Expected an INT_ARRAY tag, but found a different tag ({tag_type}).") @@ -1287,7 +1054,6 @@ def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) @override def __repr__(self) -> str: - """Get a string representation of the IntArrayNBT tag.""" if self.name: return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" if len(self.payload) < 8: @@ -1301,7 +1067,6 @@ def __iter__(self) -> Iterator[int]: @property @override def value(self) -> list[int]: - """Get the list of integers in the IntArrayNBT tag.""" return self.payload @@ -1312,14 +1077,6 @@ class LongArrayNBT(IntArrayNBT): @override def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: - """Write the LongArrayNBT tag to the buffer. - - :param buf: The buffer to write to. - :param with_type: Whether to include the type of the tag in the serialization. - :param with_name: Whether to include the name of the tag in the serialization. - - :note: The length of the long array is written as a signed 32-bit integer in big-endian format. - """ self._write_header(buf, with_type=with_type, with_name=with_name) if any(not isinstance(item, int) for item in self.payload): # type: ignore # We want to check anyway @@ -1335,15 +1092,6 @@ def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) @override @classmethod def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongArrayNBT: - """Read the LongArrayNBT tag from the buffer. - - :param buf: The buffer to read from. - :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class - will be used. - :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. - - :return: The LongArrayNBT tag. - """ name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) if tag_type != NBTagType.LONG_ARRAY: raise TypeError(f"Expected a LONG_ARRAY tag, but found a different tag ({tag_type}).") diff --git a/tests/mcproto/types/test_nbt.py b/tests/mcproto/types/test_nbt.py index f68f6e16..4ac0191c 100644 --- a/tests/mcproto/types/test_nbt.py +++ b/tests/mcproto/types/test_nbt.py @@ -2,7 +2,6 @@ import struct from typing import Any, Dict, List, cast -from typing_extensions import override import pytest @@ -24,7 +23,6 @@ PayloadType, ShortNBT, StringNBT, - NBTagConvertible, ) # region EndNBT @@ -79,11 +77,11 @@ def test_serialize_deserialize_end(): (LongNBT, -1, bytearray.fromhex("04 FF FF FF FF FF FF FF FF")), (LongNBT, 12, bytearray.fromhex("04 00 00 00 00 00 00 00 0C")), (FloatNBT, 1.0, bytearray.fromhex("05") + bytes(struct.pack(">f", 1.0))), - (FloatNBT, 3.14, bytearray.fromhex("05") + bytes(struct.pack(">f", 3.14))), + (FloatNBT, 0.25, bytearray.fromhex("05") + bytes(struct.pack(">f", 0.25))), (FloatNBT, -1.0, bytearray.fromhex("05") + bytes(struct.pack(">f", -1.0))), (FloatNBT, 12.0, bytearray.fromhex("05") + bytes(struct.pack(">f", 12.0))), (DoubleNBT, 1.0, bytearray.fromhex("06") + bytes(struct.pack(">d", 1.0))), - (DoubleNBT, 3.14, bytearray.fromhex("06") + bytes(struct.pack(">d", 3.14))), + (DoubleNBT, 0.25, bytearray.fromhex("06") + bytes(struct.pack(">d", 0.25))), (DoubleNBT, -1.0, bytearray.fromhex("06") + bytes(struct.pack(">d", -1.0))), (DoubleNBT, 12.0, bytearray.fromhex("06") + bytes(struct.pack(">d", 12.0))), (ByteArrayNBT, b"", bytearray.fromhex("07 00 00 00 00")), @@ -379,9 +377,9 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType ), ( FloatNBT, - 3.14, + 0.25, "a", - bytearray.fromhex("05") + b"\x00\x01a" + bytes(struct.pack(">f", 3.14)), + bytearray.fromhex("05") + b"\x00\x01a" + bytes(struct.pack(">f", 0.25)), ), ( FloatNBT, @@ -403,9 +401,9 @@ def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType ), ( DoubleNBT, - 3.14, + 0.25, "a", - bytearray.fromhex("06") + b"\x00\x01a" + bytes(struct.pack(">d", 3.14)), + bytearray.fromhex("06") + b"\x00\x01a" + bytes(struct.pack(">d", 0.25)), ), ( DoubleNBT, @@ -1024,7 +1022,7 @@ def test_nbt_bigfile(): Slightly modified from the source data to also include a IntArrayNBT and a LongArrayNBT. Source data: https://wiki.vg/NBT#Example. """ - data = "0a00054c6576656c0400086c6f6e67546573747fffffffffffffff02000973686f7274546573747fff08000a737472696e6754657374002948454c4c4f20574f524c4420544849532049532041205445535420535452494e4720c385c384c39621050009666c6f6174546573743eff1832030007696e74546573747fffffff0a00146e657374656420636f6d706f756e6420746573740a000368616d0800046e616d65000648616d70757305000576616c75653f400000000a00036567670800046e616d6500074567676265727405000576616c75653f00000000000c000f6c6973745465737420286c6f6e672900000005000000000000000b000000000000000c000000000000000d000000000000000e7fffffffffffffff0b000e6c697374546573742028696e7429000000047fffffff7ffffffe7ffffffd7ffffffc0900136c697374546573742028636f6d706f756e64290a000000020800046e616d65000f436f6d706f756e642074616720233004000a637265617465642d6f6e000001265237d58d000800046e616d65000f436f6d706f756e642074616720233104000a637265617465642d6f6e000001265237d58d0001000862797465546573747f07006562797465417272617954657374202874686520666972737420313030302076616c756573206f6620286e2a6e2a3235352b6e2a3729253130302c207374617274696e672077697468206e3d302028302c2036322c2033342c2031362c20382c202e2e2e2929000003e8003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a063005000a646f75626c65546573743efc7b5e00" # noqa: E501 + data = "0a00054c6576656c0400086c6f6e67546573747fffffffffffffff02000973686f7274546573747fff08000a737472696e6754657374002948454c4c4f20574f524c4420544849532049532041205445535420535452494e4720c385c384c39621050009666c6f6174546573743eff1832030007696e74546573747fffffff0a00146e657374656420636f6d706f756e6420746573740a000368616d0800046e616d65000648616d70757305000576616c75653f400000000a00036567670800046e616d6500074567676265727405000576616c75653f00000000000c000f6c6973745465737420286c6f6e672900000005000000000000000b000000000000000c000000000000000d000000000000000e7fffffffffffffff0b000e6c697374546573742028696e7429000000047fffffff7ffffffe7ffffffd7ffffffc0900136c697374546573742028636f6d706f756e64290a000000020800046e616d65000f436f6d706f756e642074616720233004000a637265617465642d6f6e000001265237d58d000800046e616d65000f436f6d706f756e642074616720233104000a637265617465642d6f6e000001265237d58d0001000862797465546573747f07006562797465417272617954657374202874686520666972737420313030302076616c756573206f6620286e2a6e2a3235352b6e2a3729253130302c207374617274696e672077697468206e3d302028302c2036322c2033342c2031362c20382c202e2e2e2929000003e8003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a063005000a646f75626c65546573743efc000000" # noqa: E501 data = bytes.fromhex(data) buffer = Buffer(data) @@ -1047,7 +1045,7 @@ def test_nbt_bigfile(): "byteTest": 127, "byteArrayTest (the first 1000 values of (n*n*255+n*7)%100, " "starting with n=0 (0, 62, 34, 16, 8, ...))": bytes((n * n * 255 + n * 7) % 100 for n in range(1000)), - "doubleTest": 0.4931287132182315, + "doubleTest": 0.4921875, } expected_schema = { "longTest": LongNBT, @@ -1117,23 +1115,16 @@ def check_equality(self: object, other: object) -> bool: # region Edge cases -def test_from_object_lst_not_same_type(): - """Test from_object with a list that does not have the same type.""" - with pytest.raises(TypeError): - NBTag.from_object([0, "test"], [IntNBT, StringNBT]) - - def test_from_object_morecases(): """Test from_object with more edge cases.""" - class CustomType(NBTagConvertible): - @override + class CustomType: def to_nbt(self, name: str = "") -> NBTag: return ByteArrayNBT(b"CustomType", name) assert NBTag.from_object( { - "nbtag": ByteNBT(0), # ByteNBT + "number": ByteNBT(0), # ByteNBT "bytearray": b"test", # Conversion from bytes "empty_list": [], # Empty list with type EndNBT "empty_compound": {}, # Empty compound @@ -1144,7 +1135,7 @@ def to_nbt(self, name: str = "") -> NBTag: ], }, { - "nbtag": ByteNBT, + "number": ByteNBT, "bytearray": ByteArrayNBT, "empty_list": [], "empty_compound": {}, @@ -1157,7 +1148,7 @@ def to_nbt(self, name: str = "") -> NBTag: ByteArrayNBT(b"test", "bytearray"), ByteArrayNBT(b"CustomType", "custom"), ListNBT([], "empty_list"), - ByteNBT(0, "nbtag"), + ByteNBT(0, "number"), ListNBT( [ListNBT([IntNBT(0), IntNBT(1), IntNBT(2)]), ListNBT([ShortNBT(3), ShortNBT(4), ShortNBT(5)])], "recursive_list", @@ -1166,14 +1157,8 @@ def to_nbt(self, name: str = "") -> NBTag: ) compound = CompoundNBT.from_object( - { - "test": ByteNBT(0), - "test2": IntNBT(0), - }, - { - "test": ByteNBT, - "test2": IntNBT, - }, + {"test": 0, "test2": 0}, + {"test": ByteNBT, "test2": IntNBT}, name="compound", ) @@ -1204,7 +1189,7 @@ def to_nbt(self, name: str = "") -> NBTag: {"test": [1, 0]}, {"test": [ByteNBT, IntNBT]}, TypeError, - "Expected a list of lists or dictionaries, but found a different type.", + "The schema must contain a single type of elements. .*", ), # Schema and data have different lengths ( @@ -1213,8 +1198,7 @@ def to_nbt(self, name: str = "") -> NBTag: ValueError, "The schema and the data must have the same length.", ), - # schema empty, data is not - ([1], [], ValueError, "The schema is empty, but the data is not."), + ([1], [], ValueError, "The schema and the data must have the same length."), # Schema is a dict, data is not (["test"], {"test": ByteNBT}, TypeError, "Expected a dictionary, but found a different type."), # Schema is not a dict, list or subclass of NBTagConvertible @@ -1222,20 +1206,34 @@ def to_nbt(self, name: str = "") -> NBTag: ["test"], "test", TypeError, - "The schema must be a list, dict or a subclass of either NBTag or NBTagConvertible.", + "The schema must be a list, dict, a subclass of NBTag or an object with a `to_nbt` method.", + ), + # Schema contains a mix of dict and list + ( + [{"test": 0}, [1, 2, 3]], + [{"test": ByteNBT}, [IntNBT]], + TypeError, + "Expected a list of lists or dictionaries, but found a different type", + ), + # Schema contains multiple types + ( + [[0], [-1]], + [IntArrayNBT, LongArrayNBT], + TypeError, + "The schema must contain a single type of elements.", ), # Schema contains CompoundNBT or ListNBT instead of a dict or list ( {"test": 0}, CompoundNBT, ValueError, - "The schema must specify the type of the elements in CompoundNBT and ListNBT tags.", + "Use a list or a dictionary in the schema to create a CompoundNBT or a ListNBT.", ), ( ["test"], ListNBT, ValueError, - "The schema must specify the type of the elements in CompoundNBT and ListNBT tags.", + "Use a list or a dictionary in the schema to create a CompoundNBT or a ListNBT.", ), # The schema specifies a type, but the data is a dict with more than one key ( @@ -1266,11 +1264,27 @@ def test_from_object_error(data: Any, schema: Any, error: type[Exception], error NBTag.from_object(data, schema) +def test_from_object_more_errors(): + """Test from_object with more edge cases.""" + # Redefine the name of the tag + schema = ByteNBT + data = {"test": 0} + with pytest.raises(ValueError): + NBTag.from_object(data, schema, name="othername") + + class CustomType: + def to_nbt(self, name: str = "") -> NBTag: + return ByteArrayNBT(b"CustomType", name) + + # Wrong data type + with pytest.raises(TypeError): + NBTag.from_object(0, CustomType) + + def test_to_object_morecases(): """Test to_object with more edge cases.""" - class CustomType(NBTagConvertible): - @override + class CustomType: def to_nbt(self, name: str = "") -> NBTag: return ByteArrayNBT(b"CustomType", name)