diff --git a/changes/257.feature.md b/changes/257.feature.md new file mode 100644 index 00000000..01125394 --- /dev/null +++ b/changes/257.feature.md @@ -0,0 +1,9 @@ +- Added the `NBTag` class to deal with NBT data: + - The `NBTag` class is the base class for all NBT tags and provides the basic functionality to serialize and deserialize NBT data to and from a `Buffer` object. + - The classes `EndNBT`, `ByteNBT`, `ShortNBT`, `IntNBT`, `LongNBT`, `FloatNBT`, `DoubleNBT`, `ByteArrayNBT`, `StringNBT`, `ListNBT`, `CompoundNBT`, `IntArrayNBT` and `LongArrayNBT` were added and correspond to the NBT types described in the [NBT specification](https://wiki.vg/NBT#Specification). + - NBT tags can be created using the `NBTag.from_object()` method, which automatically selects the correct tag type based on the object's type and works recursively for lists and dictionaries. + - The `NBTag.to_object()` method can be used to convert an NBT tag back to a Python object. + - The `NBTag.serialize()` method can be used to serialize an NBT tag to a new `Buffer` object. + - The `NBTag.deserialize(buffer)` method can be used to deserialize an NBT tag from a `Buffer` object. + - If the buffer already exists, the `NBTag.write_to(buffer, with_type=True, with_name=True)` method can be used to write the NBT tag to the buffer (optionally including the tag type and name, in the correct format). + - The `NBTag.read_from(buffer, with_type=True, with_name=True)` method can be used to read an NBT tag from the buffer (optionally reading the tag type and name, in the correct format). 
\ No newline at end of file diff --git a/mcproto/types/nbt.py b/mcproto/types/nbt.py new file mode 100644 index 00000000..e8a057bb --- /dev/null +++ b/mcproto/types/nbt.py @@ -0,0 +1,1315 @@ +from __future__ import annotations + +import warnings +from abc import ABCMeta +from enum import IntEnum +from typing import ClassVar, List, Mapping, Union, cast + +from typing_extensions import TypeAlias + +from mcproto.buffer import Buffer +from mcproto.protocol.base_io import StructFormat +from mcproto.types.abc import MCType + +__all__ = [ + "NBTagType", + "NBTag", + "EndNBT", + "ByteNBT", + "ShortNBT", + "IntNBT", + "LongNBT", + "FloatNBT", + "DoubleNBT", + "ByteArrayNBT", + "StringNBT", + "ListNBT", + "CompoundNBT", + "IntArrayNBT", + "LongArrayNBT", +] + + +# region NBT Specification +""" +Source : https://web.archive.org/web/20110723210920/http://www.minecraft.net/docs/NBT.txt + +Named Binary Tag specification + +NBT (Named Binary Tag) is a tag based binary format designed to carry large amounts of binary data with smaller amounts +of additional data. +An NBT file consists of a single GZIPped Named Tag of type TAG_Compound. + +A Named Tag has the following format: + + byte tagType + TAG_String name + [payload] + +The tagType is a single byte defining the contents of the payload of the tag. + +The name is a descriptive name, and can be anything (eg "cat", "banana", "Hello World!"). It has nothing to do with the +tagType. +The purpose for this name is to name tags so parsing is easier and can be made to only look for certain recognized tag +names. +Exception: If tagType is TAG_End, the name is skipped and assumed to be "". + +The [payload] varies by tagType. + +Note that ONLY Named Tags carry the name and tagType data. Explicitly identified Tags (such as TAG_String above) only +contains the payload. + + +The tag types and respective payloads are: + + TYPE: 0 NAME: TAG_End + Payload: None. + Note: This tag is used to mark the end of a list. + Cannot be named! 
If type 0 appears where a Named Tag is expected, the name is assumed to be "". + (In other words, this Tag is always just a single 0 byte when named, and nothing in all other cases) + + TYPE: 1 NAME: TAG_Byte + Payload: A single signed byte (8 bits) + + TYPE: 2 NAME: TAG_Short + Payload: A signed short (16 bits, big endian) + + TYPE: 3 NAME: TAG_Int + Payload: A signed integer (32 bits, big endian) + + TYPE: 4 NAME: TAG_Long + Payload: A signed long (64 bits, big endian) + + TYPE: 5 NAME: TAG_Float + Payload: A floating point value (32 bits, big endian, IEEE 754-2008, binary32) + + TYPE: 6 NAME: TAG_Double + Payload: A floating point value (64 bits, big endian, IEEE 754-2008, binary64) + + TYPE: 7 NAME: TAG_Byte_Array + Payload: TAG_Int length + An array of bytes of unspecified format. The length of this array is <length> bytes + + TYPE: 8 NAME: TAG_String + Payload: TAG_Short length + An array of bytes defining a string in UTF-8 format. The length of this array is <length> bytes + + TYPE: 9 NAME: TAG_List + Payload: TAG_Byte tagId + TAG_Int length + A sequential list of Tags (not Named Tags), of type <tagId>. The length of this array is <length> Tags + Notes: All tags share the same type. + + TYPE: 10 NAME: TAG_Compound + Payload: A sequential list of Named Tags. This array keeps going until a TAG_End is found. + TAG_End end + Notes: If there's a nested TAG_Compound within this tag, that one will also have a TAG_End, so simply reading + until the next TAG_End will not work. + The names of the named tags have to be unique within each TAG_Compound + The order of the tags is not guaranteed. + + + // NEW TAGS + TYPE: 11 NAME: TAG_Int_Array + Payload: TAG_Int length + An array of integers. The length of this array is <length> integers + + TYPE: 12 NAME: TAG_Long_Array + Payload: TAG_Int length + An array of longs. 
The length of this array is longs + + +""" +# endregion +# region NBT base classes/types + + +class NBTagType(IntEnum): + """Types of NBT tags.""" + + END = 0 + BYTE = 1 + SHORT = 2 + INT = 3 + LONG = 4 + FLOAT = 5 + DOUBLE = 6 + BYTE_ARRAY = 7 + STRING = 8 + LIST = 9 + COMPOUND = 10 + INT_ARRAY = 11 + LONG_ARRAY = 12 + + +PayloadType: TypeAlias = Union[ + int, + float, + bytearray, + bytes, + str, + List["PayloadType"], + Mapping[str, "PayloadType"], + List[int], + "NBTag", + List["NBTag"], +] + + +class _MetaNBTag(ABCMeta): + """Metaclass for NBT tags.""" + + TYPE: NBTagType = NBTagType.COMPOUND + + def __new__(cls, name: str, bases: tuple[type], namespace: dict, **kwargs): + new_cls: NBTag = super().__new__(cls, name, bases, namespace) # type: ignore + if name != "NBTag": + NBTag.ASSOCIATED_TYPES[new_cls.TYPE] = new_cls # type: ignore + return new_cls + + +class NBTag(MCType, metaclass=_MetaNBTag): + """Base class for NBT tags.""" + + __slots__ = ("name", "payload") + + TYPE: ClassVar[NBTagType] = NBTagType.COMPOUND + + ASSOCIATED_TYPES: ClassVar[dict[NBTagType, type[NBTag]]] = {} + + def __init__(self, payload: PayloadType, name: str = ""): + if self.__class__ == NBTag: + raise TypeError("Cannot instantiate an NBTag object directly, use a subclass instead.") + self.name = name + self.payload = payload + + def serialize(self, with_type: bool = True, with_name: bool = True) -> Buffer: + """Serialize the NBT tag to a buffer. + + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: These parameters only control the first level of serialization. + :return: The buffer containing the serialized NBT tag. 
+ """ + buf = Buffer() + self.write_to(buf, with_name=with_name, with_type=with_type) + return buf + + def _write_header(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> bool: + if with_type: + buf.write_value(StructFormat.BYTE, self.TYPE.value) + if self.TYPE == NBTagType.END: + return False + if with_name: + if not self.name: + raise ValueError("Named tags must have a name.") + StringNBT(self.name).write_to(buf, with_type=False, with_name=False) + return True + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the NBT tag to the buffer.""" + ... + + @classmethod + def deserialize(cls, buf: Buffer, with_name: bool = True, with_type: bool = True) -> NBTag: + """Deserialize the NBT tag. + + :param buf: The buffer to read from. + :param with_name: Whether to read the name of the tag. If set to False, the tag will have the name "". + :param with_type: Whether to read the type of the tag. + If set to True, the tag will be read from the buffer and + the return type will be inferred from the tag type. + If set to False and called from a subclass, the tag type will be inferred from the subclass. + If set to False and called from the base class, the tag type will be TAG_Compound. + + If with_type is set to False, the buffer must not start with the tag type byte. + + :return: The deserialized NBT tag. + """ + name, tag_type = cls._read_header(buf, with_name=with_name, read_type=with_type) + + tag_class = NBTag.ASSOCIATED_TYPES[tag_type] + if cls not in (NBTag, tag_class): + raise TypeError(f"Expected a {cls.__name__} tag, but found a different tag ({tag_class.__name__}).") + + tag = tag_class.read_from(buf, with_type=False, with_name=False) + tag.name = name + return tag + + @classmethod + def _read_header(cls, buf: Buffer, read_type: bool = True, with_name: bool = True) -> tuple[str, NBTagType]: + """Read the header of the NBT tag. + + :param buf: The buffer to read from. 
+ :param read_type: Whether to read the type of the tag from the buffer. + :param with_name: Whether to read the name of the tag. If set to False, the tag will have the name "". + + :return: A tuple containing the name and the tag type. + + + :note: It is possible that this function reads nothing from the buffer if both with_name and read_type are set + to False. + """ + tag_type: NBTagType = cls.TYPE # default value + if read_type: + try: + tag_type = NBTagType(buf.read_value(StructFormat.BYTE)) + except OSError: + raise IOError("Buffer is empty.") from None + except ValueError: + raise TypeError("Invalid tag type.") from None + + if tag_type == NBTagType.END: + return "", tag_type + + name = StringNBT.read_from(buf, with_type=False, with_name=False).value if with_name else "" + + return name, tag_type + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> NBTag: + """Read the NBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag. + If set to True, the tag will be read from the buffer and + the return type will be inferred from the tag type. + If set to False and called from a subclass, the tag type will be inferred from the subclass. + If set to False and called from the base class, the tag type will be TAG_Compound. + :param with_name: Whether to read the name of the tag. If set to False, the tag will have the name "". + + :return: The NBT tag. + """ + return cls.deserialize(buf, with_name=with_name, with_type=with_type) + + @staticmethod + def from_object(data: object, /, name: str = "", *, use_int_array: bool = True) -> NBTag: # noqa: PLR0911,PLR0912 + """Create an NBT tag from an arbitrary (compatible) Python object. + + :param data: The object to convert to an NBT tag. + :param use_int_array: Whether to use IntArrayNBT and LongArrayNBT for lists of integers. + If set to False, all lists of integers will be considered as ListNBT. 
+ :param name: The name of the resulting tag. Used for recursive calls. + + :return: The NBT tag representing the object. + + :note: The function will attempt to convert the object to an NBT tag in the following way: + - If the object is a dictionary with a single key, the key will be used as the name of the tag. + - If the object is an integer, it will be converted to a ByteNBT, ShortNBT, IntNBT, or LongNBT tag + depending on the value. + - If the object is a list, it will be converted to a ListNBT tag. + - If the object is a dictionary, it will be converted to a CompoundNBT tag. + - If the object is a string, it will be converted to a StringNBT tag. + - If the object is a float, it will be converted to a FloatNBT tag. + - If the object can be serialized to bytes, it will be converted to a ByteArrayNBT tag. + + + - If you want an object to be serialized in a specific way, you can implement: + + ```python + def to_nbt(self, name: str = "") -> NBTag: + ... + ``` + """ + if hasattr(data, "to_nbt"): # For objects that can be converted to NBT + return data.to_nbt(name=name) # type: ignore + + if isinstance(data, int): + if -(1 << 7) <= data < 1 << 7: + return ByteNBT(data, name=name) + if -(1 << 15) <= data < 1 << 15: + return ShortNBT(data, name=name) + if -(1 << 31) <= data < 1 << 31: + return IntNBT(data, name=name) + if -(1 << 63) <= data < 1 << 63: + return LongNBT(data, name=name) + raise ValueError(f"Integer {data} is out of range.") + if isinstance(data, float): + return FloatNBT(data, name=name) + if isinstance(data, str): + return StringNBT(data, name=name) + if isinstance(data, (bytearray, bytes)): + if isinstance(data, bytearray): + data = bytes(data) + return ByteArrayNBT(data, name=name) + if isinstance(data, list): + if not data: + # Type END is used to mark an empty list + return ListNBT([], name=name) + first_type = type(data[0]) + if any(type(item) != first_type for item in data): + raise TypeError("All items in a list must be of the same type.") + + 
if issubclass(first_type, int) and use_int_array: + # Check the range of the integers in the list + use_int = all(-(1 << 31) <= item < 1 << 31 for item in data) + use_long = all(-(1 << 63) <= item < 1 << 63 for item in data) + if use_int: + return IntArrayNBT(data, name=name) + if not use_long: # Too big to fit in a long, won't fit in a List of Longs either + raise ValueError("Integer list contains values out of range.") + return LongArrayNBT(data, name=name) + return ListNBT([NBTag.from_object(item, use_int_array=use_int_array) for item in data], name=name) + if isinstance(data, dict): + if len(data) == 0: + return CompoundNBT([], name=name) + if len(data) == 1 and name == "": + key, value = next(iter(data.items())) + return NBTag.from_object(value, name=key, use_int_array=use_int_array) + payload = [] + for key, value in data.items(): + tag = NBTag.from_object(value, name=key, use_int_array=use_int_array) + payload.append(tag) + return CompoundNBT(payload, name) + if data is None: + warnings.warn("Converting None to an END tag.", stacklevel=2) + return EndNBT() # Should not be used + + try: + # Check if the object can be converted to bytes + return ByteArrayNBT(bytes(data), name=name) # type: ignore + except (TypeError, ValueError): + pass + raise TypeError(f"Cannot convert object of type {type(data)} to an NBT tag.") + + def to_object(self) -> Mapping[str, PayloadType] | PayloadType: + """Convert the NBT payload to a dictionary.""" + return CompoundNBT(self.payload).to_object() # allow NBTag.to_object to act as a dict + + def __getitem__(self, key: str | int) -> PayloadType: + """Get a tag from the list or compound tag.""" + if self.TYPE not in (NBTagType.LIST, NBTagType.COMPOUND, NBTagType.INT_ARRAY, NBTagType.LONG_ARRAY): + raise TypeError(f"Cannot get a tag by index from a non-LIST or non-COMPOUND tag ({self.TYPE}).") + + if not isinstance(self.payload, list): + raise AttributeError( + f"The payload of the tag is not a list ({self.TYPE}).\n" + "Check that the 
initialization of the tag is correct." + ) + if not isinstance(key, (str, int)): # type: ignore + raise TypeError("Key must be a string or an integer.") + + if isinstance(key, str): + if self.TYPE != NBTagType.COMPOUND: + raise TypeError(f"Cannot get a tag by name from a non-COMPOUND tag ({self.TYPE}).") + if not all(isinstance(tag, NBTag) for tag in self.payload): + raise AttributeError("The payload of the tag is not a list of NBTag objects.") + for tag in self.payload: + tag = cast(NBTag, tag) + if tag.name == key: + return tag + raise KeyError(f"No tag with the name {key!r} found.") + + # Key is an integer + if key < -len(self.payload) or key >= len(self.payload): + raise IndexError(f"Index {key} out of range.") + return self.payload[key] + + def __repr__(self) -> str: + if self.name: + return f"{self.__class__.__name__}[{self.name!r}]({self.payload!r})" + return f"{self.__class__.__name__}({self.payload!r})" + + def __eq__(self, other: object) -> bool: + """Check equality between two NBT tags.""" + if not isinstance(other, NBTag): + raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") + return self.name == other.name and self.TYPE == other.TYPE and self.payload == other.payload + + def to_nbt(self, name: str = "") -> NBTag: + """Convert the object to an NBT tag. + + ..warning This is already an NBT tag, so it will modify the name of the tag and return itself. 
+ """ + self.name = name + return self + + @property + def value(self) -> PayloadType: + """Get the payload of the NBT tag in a python-friendly format.""" + obj = self.to_object() + if isinstance(obj, dict) and self.name: + return obj[self.name] + return obj + + +# endregion +# region NBT tags types + + +class EndNBT(NBTag): + """Sentinel tag used to mark the end of a TAG_Compound.""" + + TYPE = NBTagType.END + __slots__ = () + + def __init__(self): + """Create a new EndNBT tag.""" + super().__init__(0, name="") + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the EndNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> EndNBT: + """Read the EndNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag. Has no effect on the EndNBT tag. + + :return: The EndNBT tag. + """ + _, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + return EndNBT() + + def to_object(self) -> Mapping[str, PayloadType]: + """Convert the EndNBT tag to a python object. + + :return: An empty dictionary. 
+ """ + return {} + + +class ByteNBT(NBTag): + """NBT tag representing a single byte value, represented as a signed 8-bit integer.""" + + TYPE = NBTagType.BYTE + + __slots__ = () + payload: int + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the ByteNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + if self.payload < -(1 << 7) or self.payload >= 1 << 7: + raise OverflowError("Byte value out of range.") + + buf.write_value(StructFormat.BYTE, self.payload) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteNBT: + """Read the ByteNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The ByteNBT tag. + """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 1: + raise IOError("Buffer does not contain enough data to read a byte. (Empty buffer)") + + return ByteNBT(buf.read_value(StructFormat.BYTE), name=name) + + def __int__(self) -> int: + """Get the integer value of the ByteNBT tag.""" + return self.payload + + def to_object(self) -> Mapping[str, int] | int: + """Convert the ByteNBT tag to a python object. + + :return: A dictionary containing the name and the integer value of the tag. If the tag has no name, the value + will be returned directly. 
+ """ + if self.name: + return {self.name: self.payload} + return self.payload + + @property + def value(self) -> int: + """Get the integer value of the IntNBT tag.""" + return self.payload + + +class ShortNBT(ByteNBT): + """NBT tag representing a short value, represented as a signed 16-bit integer.""" + + TYPE = NBTagType.SHORT + + __slots__ = () + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the ShortNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The short value is written as a signed 16-bit integer in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + + if self.payload < -(1 << 15) or self.payload >= 1 << 15: + raise OverflowError("Short value out of range.") + + buf.write(self.payload.to_bytes(2, "big", signed=True)) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ShortNBT: + """Read the ShortNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The ShortNBT tag. 
+ """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 2: + raise IOError("Buffer does not contain enough data to read a short.") + + return ShortNBT(int.from_bytes(buf.read(2), "big", signed=True), name=name) + + +class IntNBT(ByteNBT): + """NBT tag representing an integer value, represented as a signed 32-bit integer.""" + + TYPE = NBTagType.INT + + __slots__ = () + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the IntNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The integer value is written as a signed 32-bit integer in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + + if self.payload < -(1 << 31) or self.payload >= 1 << 31: + raise OverflowError("Integer value out of range.") + + # No more messing around with the struct, we want 32 bits of data no matter what + buf.write(self.payload.to_bytes(4, "big", signed=True)) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntNBT: + """Read the IntNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The IntNBT tag. 
+ """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 4: + raise IOError("Buffer does not contain enough data to read an int.") + + return IntNBT(int.from_bytes(buf.read(4), "big", signed=True), name=name) + + +class LongNBT(ByteNBT): + """NBT tag representing a long value, represented as a signed 64-bit integer.""" + + TYPE = NBTagType.LONG + + __slots__ = () + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the LongNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The long value is written as a signed 64-bit integer in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + + if self.payload < -(1 << 63) or self.payload >= 1 << 63: + raise OverflowError("Long value out of range.") + + # No more messing around with the struct, we want 64 bits of data no matter what + buf.write(self.payload.to_bytes(8, "big", signed=True)) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongNBT: + """Read the LongNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The LongNBT tag. 
+ """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 8: + raise IOError("Buffer does not contain enough data to read a long.") + + payload = int.from_bytes(buf.read(8), "big", signed=True) + return LongNBT(payload, name=name) + + +class FloatNBT(NBTag): + """NBT tag representing a floating-point value, represented as a 32-bit IEEE 754-2008 binary32 value.""" + + TYPE = NBTagType.FLOAT + + payload: float + + __slots__ = () + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the FloatNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The float value is written as a 32-bit floating-point value in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + buf.write_value(StructFormat.FLOAT, self.payload) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> FloatNBT: + """Read the FloatNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The FloatNBT tag. 
+ """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 4: + raise IOError("Buffer does not contain enough data to read a float.") + + return FloatNBT(buf.read_value(StructFormat.FLOAT), name=name) + + def __float__(self) -> float: + """Get the float value of the FloatNBT tag.""" + return self.payload + + def to_object(self) -> Mapping[str, float] | float: + """Convert the FloatNBT tag to a python object. + + :return: A dictionary containing the name and the float value of the tag. If the tag has no name, the value + will be returned directly. + """ + if self.name: + return {self.name: self.payload} + return self.payload + + def __eq__(self, other: object) -> bool: + """Check equality between two FloatNBT tags. + + :param other: The other FloatNBT tag to compare to. + + :return: True if the tags are equal, False otherwise. + + :note: The float values are compared with a small epsilon (1e-6) to account for floating-point errors. 
+ """ + if not isinstance(other, NBTag): + raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.") + # Compare the float values with a small epsilon + if not (self.name == other.name and self.TYPE == other.TYPE): + return False + if not isinstance(other, self.__class__): # pragma: no cover + return False # Should not happen if nobody messes with the TYPE attribute + + return abs(self.payload - other.payload) < 1e-6 + + @property + def value(self) -> float: + """Get the float value of the FloatNBT tag.""" + return self.payload + + +class DoubleNBT(FloatNBT): + """NBT tag representing a double-precision floating-point value, represented as a 64-bit IEEE 754-2008 binary64.""" + + TYPE = NBTagType.DOUBLE + + __slots__ = () + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the DoubleNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The double value is written as a 64-bit floating-point value in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + buf.write_value(StructFormat.DOUBLE, self.payload) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> DoubleNBT: + """Read the DoubleNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The DoubleNBT tag. 
+ """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 8: + raise IOError("Buffer does not contain enough data to read a double.") + + return DoubleNBT(buf.read_value(StructFormat.DOUBLE), name=name) + + +class ByteArrayNBT(NBTag): + """NBT tag representing an array of bytes. The length of the array is stored as a signed 32-bit integer.""" + + TYPE = NBTagType.BYTE_ARRAY + + __slots__ = () + + payload: bytearray + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the ByteArrayNBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + + :note: The length of the byte array is written as a signed 32-bit integer in big-endian format. + """ + self._write_header(buf, with_type=with_type, with_name=with_name) + IntNBT(len(self.payload)).write_to(buf, with_type=False, with_name=False) + buf.write(self.payload) + + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteArrayNBT: + """Read the ByteArrayNBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class + will be used. + :param with_name: Whether to read the name of the tag to the buffer as a TAG_String. + + :return: The ByteArrayNBT tag. 
+ """ + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != cls.TYPE: + raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).") + try: + length = IntNBT.read_from(buf, with_type=False, with_name=False).value + except IOError: + raise IOError("Buffer does not contain enough data to read a byte array.") from None + + if length < 0: + raise ValueError("Invalid byte array length.") + + if buf.remaining < length: + raise IOError( + f"Buffer does not contain enough data to read the byte array ({buf.remaining} < {length} bytes)." + ) + + return ByteArrayNBT(buf.read(length), name=name) + + def __bytes__(self) -> bytes: + """Get the bytes value of the ByteArrayNBT tag.""" + return self.payload + + def to_object(self) -> Mapping[str, bytearray] | bytearray: + """Convert the ByteArrayNBT tag to a python object. + + :return: A dictionary containing the name and the byte array value of the tag. If the tag has no name, the + value will be returned directly. + """ + if self.name: + return {self.name: self.payload} + return self.payload + + def __repr__(self) -> str: + """Get a string representation of the ByteArrayNBT tag.""" + if self.name: + return f"{self.__class__.__name__}[{self.name!r}](length={len(self.payload)})" + if len(self.payload) < 8: + return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{self.__class__.__name__}(length={len(self.payload)}, {bytes(self.payload[:7])!r}...)" + + @property + def value(self) -> bytearray: + """Get the bytes value of the ByteArrayNBT tag.""" + return self.payload + + +class StringNBT(NBTag): + """NBT tag representing an UTF-8 string value. The length of the string is stored as a signed 16-bit integer.""" + + TYPE = NBTagType.STRING + + __slots__ = () + + payload: str + + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the StringNBT tag to the buffer. 
class StringNBT(NBTag):
    """NBT tag representing an UTF-8 string value. The length of the string is stored as a signed 16-bit integer."""

    TYPE = NBTagType.STRING

    __slots__ = ()

    payload: str

    def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None:
        """Write the StringNBT tag to the buffer.

        :param buf: The buffer to write to.
        :param with_type: Whether to include the type of the tag in the serialization.
        :param with_name: Whether to include the name of the tag in the serialization.

        :raises ValueError: If the UTF-8 encoding of the string is longer than 32767 bytes.

        :note: The length of the string is written as a signed 16-bit integer in big-endian format.
        """
        self._write_header(buf, with_type=with_type, with_name=with_name)

        data = bytearray(self.payload, "utf-8")
        # The length prefix counts encoded bytes, not characters. A string of multi-byte characters can stay under
        # 32767 characters while its encoding exceeds what the 16-bit prefix can hold, so the limit is checked on
        # the encoded data.
        if len(data) > 32767:
            raise ValueError(  # pragma: no cover
                "Maximum length for writing strings is 32767 bytes once encoded as UTF-8."
            )

        ShortNBT(len(data)).write_to(buf, with_type=False, with_name=False)
        buf.write(data)

    @classmethod
    def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> StringNBT:
        """Read the StringNBT tag from the buffer.

        :param buf: The buffer to read from.
        :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class
            will be used.
        :param with_name: Whether to read the name of the tag from the buffer as a TAG_String.

        :return: The StringNBT tag.

        :raises TypeError: If the header announces a tag type other than STRING.
        :raises IOError: If the buffer ends before the length prefix or before the announced number of bytes.
        :raises ValueError: If the length prefix is negative.
        :raises UnicodeDecodeError: If the string data is not valid UTF-8.
        """
        name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name)
        if tag_type != cls.TYPE:
            raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).")
        try:
            length = ShortNBT.read_from(buf, with_type=False, with_name=False).value
        except IOError:
            raise IOError("Buffer does not contain enough data to read a string.") from None

        if length < 0:
            raise ValueError("Invalid string length.")

        if buf.remaining < length:
            raise IOError("Buffer does not contain enough data to read the string.")
        data = buf.read(length)
        # A UnicodeDecodeError is deliberately left to propagate: invalid data should be reported, not hidden.
        return StringNBT(data.decode("utf-8"), name=name)

    def __str__(self) -> str:
        """Get the string value of the StringNBT tag."""
        return self.payload

    def to_object(self) -> Mapping[str, str] | str:
        """Convert the StringNBT tag to a python object.

        :return: A dictionary containing the name and the string value of the tag. If the tag has no name, the value
            will be returned directly.
        """
        if self.name:
            return {self.name: self.payload}
        return self.payload

    @property
    def value(self) -> str:
        """Get the string value of the StringNBT tag."""
        return self.payload
class ListNBT(NBTag):
    """NBT tag representing a list of tags. All tags in the list must be of the same type."""

    TYPE = NBTagType.LIST

    __slots__ = ()

    payload: list[NBTag]

    def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None:
        """Write the ListNBT tag to the buffer.

        :param buf: The buffer to write to.
        :param with_type: Whether to include the type of the tag in the serialization.
        :param with_name: Whether to include the name of the tag in the serialization.

        :raises ValueError: If an item is not an NBTag, if the items do not all share the type of the first item, or
            if an item has a name.

        :note: The tag type of the list is written as a single byte, followed by the length of the list as a signed
            32-bit integer in big-endian format. The tags in the list are then serialized one by one.
        """
        self._write_header(buf, with_type=with_type, with_name=with_name)

        if not self.payload:
            # Set the tag type to TAG_End if the list is empty
            EndNBT().write_to(buf, with_name=False)
            IntNBT(0).write_to(buf, with_name=False, with_type=False)
            return

        if not all(isinstance(tag, NBTag) for tag in self.payload):  # type: ignore # We want to check anyway
            raise ValueError(
                f"All items in a list must be NBTags. Got {self.payload!r}.\nUse NBTag.from_object() to convert "
                "objects to tags first."
            )

        tag_type = self.payload[0].TYPE
        # Validate every element before the element type, the length or any element is written, so that a
        # malformed element in the middle of the list does not leave a truncated list body in the buffer.
        for tag in self.payload:
            if tag_type != tag.TYPE:
                raise ValueError(f"All tags in a list must be of the same type, got tag {tag!r}")
            if tag.name != "":
                raise ValueError(f"All tags in a list must be unnamed, got tag {tag!r}")

        ByteNBT(tag_type).write_to(buf, with_name=False, with_type=False)
        IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False)
        for tag in self.payload:
            # The element type was written once for the whole list and elements are unnamed.
            tag.write_to(buf, with_type=False, with_name=False)

    @classmethod
    def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ListNBT:
        """Read the ListNBT tag from the buffer.

        :param buf: The buffer to read from.
        :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class
            will be used.
        :param with_name: Whether to read the name of the tag from the buffer as a TAG_String.

        :return: The ListNBT tag. A negative length or an element type of END produces an empty list.

        :raises TypeError: If the header announces a tag type other than LIST, or the element type is unknown.
        :raises IOError: If the buffer ends before the list is completely read.
        """
        name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name)
        if tag_type != cls.TYPE:
            raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).")
        try:
            # The element type and length are read under the same guard, so a buffer truncated at either point is
            # reported with the same list-specific message.
            list_tag_type = ByteNBT.read_from(buf, with_type=False, with_name=False).payload
            length = IntNBT.read_from(buf, with_type=False, with_name=False).value
        except IOError:
            raise IOError("Buffer does not contain enough data to read a list.") from None

        if length < 0 or list_tag_type == NBTagType.END:
            return ListNBT([], name=name)

        try:
            list_tag_type = NBTagType(list_tag_type)
        except ValueError:
            raise TypeError(f"Unknown tag type {list_tag_type}.") from None

        # NBTag itself is used as the "not found" sentinel for the lookup.
        list_type_class = NBTag.ASSOCIATED_TYPES.get(list_tag_type, NBTag)
        if list_type_class is NBTag:
            raise TypeError(f"Unknown tag type {list_tag_type}.")  # pragma: no cover
        try:
            payload = [
                # The type is already known, so we don't need to read it again
                # List items are unnamed, so we don't need to read the name
                list_type_class.read_from(buf, with_type=False, with_name=False)
                for _ in range(length)
            ]
        except IOError:
            raise IOError("Buffer does not contain enough data to read the list.") from None
        return ListNBT(payload, name=name)

    def __iter__(self):
        """Iterate over the tags in the list."""
        yield from self.payload

    def __repr__(self) -> str:
        """Get a string representation of the ListNBT tag.

        Unnamed lists holding 8 items or more only show their first 7 items.
        """
        if self.name:
            return f"{self.__class__.__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})"
        if len(self.payload) < 8:
            return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload!r})"
        return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)"

    def to_object(self) -> Mapping[str, list[PayloadType]] | list[PayloadType]:
        """Convert the ListNBT tag to a python object.

        :return: A dictionary containing the name and the list of tags. If the tag has no name, the list will be
            returned directly.
        """
        # Extract the (unnamed) object from each tag
        items = [tag.to_object() for tag in self.payload]
        if self.name:
            return {self.name: items}
        return items
class CompoundNBT(NBTag):
    """NBT tag holding a collection of named child tags, closed by an END tag when serialized."""

    TYPE = NBTagType.COMPOUND

    __slots__ = ()

    payload: list[NBTag]

    def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None:
        """Serialize the compound into the buffer.

        :param buf: The buffer to write to.
        :param with_type: Whether to include the type of the tag in the serialization.
        :param with_name: Whether to include the name of the tag in the serialization. This only affects the name of
            the compound tag itself, not the names of the tags inside the compound.

        :raises ValueError: If a child is not an NBTag, is unnamed, or shares its name with another child.

        :note: Children are written one after the other with their own type and name, then an EndNBT tag follows.
        """
        self._write_header(buf, with_type=with_type, with_name=with_name)

        children = self.payload
        if children:
            if any(not isinstance(child, NBTag) for child in children):  # type: ignore # We want to check anyway
                raise ValueError(
                    f"All items in a compound must be NBTags. Got {self.payload!r}.\n"
                    "Use NBTag.from_object() to convert objects to tags first."
                )

            child_names = [child.name for child in children]
            if not all(child_names):
                raise ValueError(f"All tags in a compound must be named, got tags {self.payload!r}")

            # A set drops duplicates, so a size mismatch means at least one name is repeated.
            if len(set(child_names)) != len(child_names):
                raise ValueError("All tags in a compound must have unique names.")

            for child in children:
                child.write_to(buf)

        # The END tag closes the compound, whether or not it has children.
        EndNBT().write_to(buf, with_name=False, with_type=True)

    @classmethod
    def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> CompoundNBT:
        """Deserialize a compound from the buffer.

        :param buf: The buffer to read from.
        :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class
            will be used.
        :param with_name: Whether to read the name of the tag from the buffer as a TAG_String.

        :return: The CompoundNBT tag.

        :raises TypeError: If the header announces a tag type other than COMPOUND.
        """
        name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name)
        if tag_type != cls.TYPE:
            raise TypeError(f"Expected a {cls.TYPE.name} tag, but found a different tag ({tag_type.name}).")

        children = []
        child_name, child_type = cls._read_header(buf, with_name=True, read_type=True)
        while child_type != NBTagType.END:
            # The child's header was consumed above, so only its payload is left to read.
            child = NBTag.ASSOCIATED_TYPES[child_type].read_from(buf, with_type=False, with_name=False)
            child.name = child_name
            children.append(child)
            child_name, child_type = cls._read_header(buf, with_name=True, read_type=True)
        return CompoundNBT(children, name=name)

    def __iter__(self):
        """Yield a ``(name, tag)`` pair for every child of the compound."""
        yield from ((child.name, child) for child in self.payload)

    def __repr__(self) -> str:
        """Build a readable representation showing the children keyed by name."""
        label = f"[{self.name!r}]" if self.name else ""
        return f"{self.__class__.__name__}{label}({dict(self)})"

    def to_object(self) -> Mapping[str, Mapping[str, PayloadType]]:
        """Convert the CompoundNBT tag to a python object.

        :return: A dictionary mapping the compound's name to the dictionary of its children. If the compound has no
            name, the dictionary of children is returned directly.

        :raises ValueError: If two children share a name or a child has an empty name.
        """
        children = {}
        for child in self.payload:
            if child.name in children:
                raise ValueError(f"Duplicate tag name {child.name!r} in the compound.")
            if child.name == "":
                raise ValueError("All tags in a compound must have a name.")
            # A named child converts to a single-entry mapping, which is merged into the result.
            children.update(cast("dict[str, PayloadType]", child.to_object()))
        return {self.name: children} if self.name else children

    def __eq__(self, other: object) -> bool:
        """Compare this compound with another tag, ignoring the order of the children.

        :param other: The tag to compare to.

        :return: True if both tags are compounds with the same name and the same children, False otherwise.

        :raises NotImplementedError: If ``other`` is not an NBTag.

        :note: Each child of this compound is looked up in the other compound's payload after the child counts
            have been compared. This assumes that there are no duplicate tags in the compound.
        """
        if not isinstance(other, NBTag):
            raise NotImplementedError("Cannot compare an NBTag to a non-NBTag object.")
        if self.name != other.name:
            return False
        if self.TYPE != other.TYPE:
            return False
        if not isinstance(other, self.__class__):  # pragma: no cover
            return False  # Should not happen if nobody messes with the TYPE attribute
        if len(self.payload) != len(other.payload):
            return False
        # The order of the tags is not guaranteed, so membership is tested instead of position.
        for child in self.payload:
            if child not in other.payload:
                return False
        return True
class IntArrayNBT(NBTag):
    """NBT tag representing an array of integers. The length of the array is stored as a signed 32-bit integer."""

    TYPE = NBTagType.INT_ARRAY

    __slots__ = ()

    payload: list[int]

    def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None:
        """Write the IntArrayNBT tag to the buffer.

        :param buf: The buffer to write to.
        :param with_type: Whether to include the type of the tag in the serialization.
        :param with_name: Whether to include the name of the tag in the serialization.

        :raises ValueError: If an item of the array is not an integer.
        :raises OverflowError: If an item does not fit in a signed 32-bit integer.

        :note: The length of the integer array is written as a signed 32-bit integer in big-endian format.
        """
        self._write_header(buf, with_type=with_type, with_name=with_name)

        if any(not isinstance(item, int) for item in self.payload):  # type: ignore # We want to check anyway
            raise ValueError("All items in an integer array must be integers.")

        if any(item < -(1 << 31) or item >= 1 << 31 for item in self.payload):
            raise OverflowError("Integer array contains values out of range.")

        IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False)
        for i in self.payload:
            IntNBT(i).write_to(buf, with_name=False, with_type=False)

    @classmethod
    def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntArrayNBT:
        """Read the IntArrayNBT tag from the buffer.

        :param buf: The buffer to read from.
        :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class
            will be used.
        :param with_name: Whether to read the name of the tag from the buffer as a TAG_String.

        :return: The IntArrayNBT tag.

        :raises TypeError: If the header announces a tag type other than INT_ARRAY.
        :raises IOError: If the buffer ends before the length prefix or before all the announced integers.
        """
        name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name)
        if tag_type != NBTagType.INT_ARRAY:
            # ``.name`` keeps the message stable; formatting the enum member itself renders differently depending
            # on the Python version.
            raise TypeError(f"Expected an INT_ARRAY tag, but found a different tag ({tag_type.name}).")
        try:
            length = IntNBT.read_from(buf, with_type=False, with_name=False).value
            # Elements carry neither a type byte nor a name. ``with_type`` is passed explicitly as False here: it
            # used to be the positional expression ``with_type == NBTagType.INT``, a bool-to-enum comparison that
            # only happened to evaluate to False.
            # NOTE(review): a negative length produces an empty array because ``range`` is then empty.
            payload = [IntNBT.read_from(buf, with_type=False, with_name=False).value for _ in range(length)]
        except IOError:
            raise IOError(
                "Buffer does not contain enough data to read the entire integer array. (Incomplete data)"
            ) from None
        return IntArrayNBT(payload, name=name)

    def __repr__(self) -> str:
        """Get a string representation of the IntArrayNBT tag.

        Unnamed arrays holding 8 items or more only show their first 7 items.
        """
        if self.name:
            return f"{self.__class__.__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})"
        if len(self.payload) < 8:
            return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload!r})"
        return f"{self.__class__.__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)"

    def __iter__(self):
        """Iterate over the integers in the array."""
        yield from self.payload

    def to_object(self) -> Mapping[str, list[int]] | list[int]:
        """Convert the IntArrayNBT tag to a python object.

        :return: A dictionary containing the name and the list of integers. If the tag has no name, the list will be
            returned directly.
        """
        if self.name:
            return {self.name: self.payload}
        return self.payload

    @property
    def value(self) -> list[int]:
        """Get the list of integers in the IntArrayNBT tag."""
        return self.payload
class LongArrayNBT(IntArrayNBT):
    """NBT tag representing an array of longs. The length of the array is stored as a signed 32-bit integer."""

    TYPE = NBTagType.LONG_ARRAY

    __slots__ = ()

    def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None:
        """Write the LongArrayNBT tag to the buffer.

        :param buf: The buffer to write to.
        :param with_type: Whether to include the type of the tag in the serialization.
        :param with_name: Whether to include the name of the tag in the serialization.

        :raises ValueError: If an item of the array is not an integer.
        :raises OverflowError: If an item does not fit in a signed 64-bit integer.

        :note: The length of the long array is written as a signed 32-bit integer in big-endian format.
        """
        self._write_header(buf, with_type=with_type, with_name=with_name)

        if any(not isinstance(item, int) for item in self.payload):  # type: ignore # We want to check anyway
            raise ValueError(f"All items in a long array must be integers. ({self.payload})")

        if any(item < -(1 << 63) or item >= 1 << 63 for item in self.payload):
            raise OverflowError(f"Long array contains values out of range. ({self.payload})")

        IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False)
        for i in self.payload:
            LongNBT(i).write_to(buf, with_name=False, with_type=False)

    @classmethod
    def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongArrayNBT:
        """Read the LongArrayNBT tag from the buffer.

        :param buf: The buffer to read from.
        :param with_type: Whether to read the type of the tag from the buffer. If this is False, the type of the class
            will be used.
        :param with_name: Whether to read the name of the tag from the buffer as a TAG_String.

        :return: The LongArrayNBT tag.

        :raises TypeError: If the header announces a tag type other than LONG_ARRAY.
        :raises IOError: If the buffer ends before the length prefix or before all the announced longs.
        """
        name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name)
        if tag_type != NBTagType.LONG_ARRAY:
            # ``.name`` keeps the message stable; formatting the enum member itself renders differently depending
            # on the Python version.
            raise TypeError(f"Expected a LONG_ARRAY tag, but found a different tag ({tag_type.name}).")

        try:
            # The length is read under the same guard as the elements, so a buffer truncated inside the length
            # prefix is reported with the same array-specific message.
            length = IntNBT.read_from(buf, with_type=False, with_name=False).payload
            payload = [LongNBT.read_from(buf, with_type=False, with_name=False).payload for _ in range(length)]
        except IOError:
            raise IOError(
                "Buffer does not contain enough data to read the entire long array. (Incomplete data)"
            ) from None
        return LongArrayNBT(payload, name=name)
def test_serialize_deserialize_end():
    """Check that an END tag serializes to a single zero byte and is recognised when read back."""
    end_marker = bytearray.fromhex("00")

    # Serializing into a fresh buffer.
    assert EndNBT().serialize() == end_marker

    # Writing into an existing buffer, first with default options.
    target = Buffer()
    EndNBT().write_to(target)
    assert target == end_marker

    # Then with the name explicitly disabled: the output is the same single byte.
    target.clear()
    EndNBT().write_to(target, with_name=False)
    assert target == end_marker

    # Reading the byte back through the generic entry point gives an END tag.
    source = Buffer(bytearray.fromhex("00"))
    decoded = NBTag.deserialize(source)
    assert decoded.TYPE == NBTagType.END
# Each case is (tag class, payload, expected bytes). The expected bytes start with the tag type byte and contain no
# name; the test also checks the same data with that first byte removed.
@pytest.mark.parametrize(
    ("nbt_class", "value", "expected_bytes"),
    [
        (ByteNBT, 0, bytearray.fromhex("01 00")),
        (ByteNBT, 1, bytearray.fromhex("01 01")),
        (ByteNBT, 127, bytearray.fromhex("01 7F")),
        (ByteNBT, -128, bytearray.fromhex("01 80")),
        (ByteNBT, -1, bytearray.fromhex("01 FF")),
        (ByteNBT, 12, bytearray.fromhex("01 0C")),
        (ShortNBT, 0, bytearray.fromhex("02 00 00")),
        (ShortNBT, 1, bytearray.fromhex("02 00 01")),
        (ShortNBT, 32767, bytearray.fromhex("02 7F FF")),
        (ShortNBT, -32768, bytearray.fromhex("02 80 00")),
        (ShortNBT, -1, bytearray.fromhex("02 FF FF")),
        (ShortNBT, 12, bytearray.fromhex("02 00 0C")),
        (IntNBT, 0, bytearray.fromhex("03 00 00 00 00")),
        (IntNBT, 1, bytearray.fromhex("03 00 00 00 01")),
        (IntNBT, 2147483647, bytearray.fromhex("03 7F FF FF FF")),
        (IntNBT, -2147483648, bytearray.fromhex("03 80 00 00 00")),
        (IntNBT, -1, bytearray.fromhex("03 FF FF FF FF")),
        (IntNBT, 12, bytearray.fromhex("03 00 00 00 0C")),
        (LongNBT, 0, bytearray.fromhex("04 00 00 00 00 00 00 00 00")),
        (LongNBT, 1, bytearray.fromhex("04 00 00 00 00 00 00 00 01")),
        (LongNBT, (1 << 63) - 1, bytearray.fromhex("04 7F FF FF FF FF FF FF FF")),
        (LongNBT, -(1 << 63), bytearray.fromhex("04 80 00 00 00 00 00 00 00")),
        (LongNBT, -1, bytearray.fromhex("04 FF FF FF FF FF FF FF FF")),
        (LongNBT, 12, bytearray.fromhex("04 00 00 00 00 00 00 00 0C")),
        (FloatNBT, 1.0, bytearray.fromhex("05") + bytes(struct.pack(">f", 1.0))),
        (FloatNBT, 3.14, bytearray.fromhex("05") + bytes(struct.pack(">f", 3.14))),
        (FloatNBT, -1.0, bytearray.fromhex("05") + bytes(struct.pack(">f", -1.0))),
        (FloatNBT, 12.0, bytearray.fromhex("05") + bytes(struct.pack(">f", 12.0))),
        (DoubleNBT, 1.0, bytearray.fromhex("06") + bytes(struct.pack(">d", 1.0))),
        (DoubleNBT, 3.14, bytearray.fromhex("06") + bytes(struct.pack(">d", 3.14))),
        (DoubleNBT, -1.0, bytearray.fromhex("06") + bytes(struct.pack(">d", -1.0))),
        (DoubleNBT, 12.0, bytearray.fromhex("06") + bytes(struct.pack(">d", 12.0))),
        (ByteArrayNBT, b"", bytearray.fromhex("07 00 00 00 00")),
        (ByteArrayNBT, b"\x00", bytearray.fromhex("07 00 00 00 01") + b"\x00"),
        (ByteArrayNBT, b"\x00\x01", bytearray.fromhex("07 00 00 00 02") + b"\x00\x01"),
        (ByteArrayNBT, b"\x00\x01\x02", bytearray.fromhex("07 00 00 00 03") + b"\x00\x01\x02"),
        (ByteArrayNBT, b"\x00\x01\x02\x03", bytearray.fromhex("07 00 00 00 04") + b"\x00\x01\x02\x03"),
        (ByteArrayNBT, b"\xFF" * 1024, bytearray.fromhex("07 00 00 04 00") + b"\xFF" * 1024),
        (
            ByteArrayNBT,
            bytes((n - 1) * n * 2 % 256 for n in range(256)),
            bytearray.fromhex("07 00 00 01 00") + bytes((n - 1) * n * 2 % 256 for n in range(256)),
        ),
        (StringNBT, "", bytearray.fromhex("08 00 00")),
        (StringNBT, "test", bytearray.fromhex("08 00 04") + b"test"),
        (StringNBT, "a" * ((1 << 15) - 1), bytearray.fromhex("08 7F FF") + b"a" * ((1 << 15) - 1)),
        (StringNBT, "&à@é", bytearray.fromhex("08 00 06") + bytes("&à@é", "utf-8")),
        (ListNBT, [], bytearray.fromhex("09 00 00 00 00 00")),
        (ListNBT, [ByteNBT(0)], bytearray.fromhex("09 01 00 00 00 01 00")),
        (ListNBT, [ShortNBT(127), ShortNBT(256)], bytearray.fromhex("09 02 00 00 00 02 00 7F 01 00")),
        (
            ListNBT,
            [ListNBT([ByteNBT(0)]), ListNBT([IntNBT(256)])],
            bytearray.fromhex("09 09 00 00 00 02 01 00 00 00 01 00 03 00 00 00 01 00 00 01 00"),
        ),
        (CompoundNBT, [], bytearray.fromhex("0A 00")),
        (
            CompoundNBT,
            [ByteNBT(0, name="test")],
            bytearray.fromhex("0A") + ByteNBT(0, name="test").serialize() + b"\x00",
        ),
        (
            CompoundNBT,
            [ShortNBT(128, "Short"), ByteNBT(-1, "Byte")],
            bytearray.fromhex("0A") + ShortNBT(128, "Short").serialize() + ByteNBT(-1, "Byte").serialize() + b"\x00",
        ),
        (
            CompoundNBT,
            [CompoundNBT([ByteNBT(0, name="Byte")], name="test")],
            bytearray.fromhex("0A") + CompoundNBT([ByteNBT(0, name="Byte")], name="test").serialize() + b"\x00",
        ),
        (
            CompoundNBT,
            [CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test"), IntNBT(-1, "Int 2")],
            bytearray.fromhex("0A")
            + CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test").serialize()
            + IntNBT(-1, "Int 2").serialize()
            + b"\x00",
        ),
        (IntArrayNBT, [], bytearray.fromhex("0B 00 00 00 00")),
        (IntArrayNBT, [0], bytearray.fromhex("0B 00 00 00 01 00 00 00 00")),
        (IntArrayNBT, [0, 1], bytearray.fromhex("0B 00 00 00 02 00 00 00 00 00 00 00 01")),
        (IntArrayNBT, [1, 2, 3], bytearray.fromhex("0B 00 00 00 03 00 00 00 01 00 00 00 02 00 00 00 03")),
        (IntArrayNBT, [(1 << 31) - 1], bytearray.fromhex("0B 00 00 00 01 7F FF FF FF")),
        (IntArrayNBT, [(1 << 31) - 1, (1 << 31) - 2], bytearray.fromhex("0B 00 00 00 02 7F FF FF FF 7F FF FF FE")),
        (IntArrayNBT, [-1, -2, -3], bytearray.fromhex("0B 00 00 00 03 FF FF FF FF FF FF FF FE FF FF FF FD")),
        (IntArrayNBT, [12] * 1024, bytearray.fromhex("0B 00 00 04 00") + b"\x00\x00\x00\x0C" * 1024),
        (LongArrayNBT, [], bytearray.fromhex("0C 00 00 00 00")),
        (LongArrayNBT, [0], bytearray.fromhex("0C 00 00 00 01 00 00 00 00 00 00 00 00")),
        (LongArrayNBT, [0, 1], bytearray.fromhex("0C 00 00 00 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 01")),
        (
            LongArrayNBT,
            [1, 2, 3],
            bytearray.fromhex(
                "0C 00 00 00 03 00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 03"
            ),
        ),
        (LongArrayNBT, [(1 << 63) - 1], bytearray.fromhex("0C 00 00 00 01 7F FF FF FF FF FF FF FF")),
        (
            LongArrayNBT,
            [(1 << 63) - 1, (1 << 63) - 2],
            bytearray.fromhex("0C 00 00 00 02 7F FF FF FF FF FF FF FF 7F FF FF FF FF FF FF FE"),
        ),
        (
            LongArrayNBT,
            [-1, -2, -3],
            bytearray.fromhex(
                "0C 00 00 00 03 FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF FE FF FF FF FF FF FF FF FD"
            ),
        ),
        (LongArrayNBT, [12] * 1024, bytearray.fromhex("0C 00 00 04 00") + b"\x00\x00\x00\x00\x00\x00\x00\x0C" * 1024),
    ],
)
def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType, expected_bytes: bytes):
    """Test serialization/deserialization of NBT tag without name.

    :param nbt_class: The tag class under test.
    :param value: The payload used to build the tag.
    :param expected_bytes: The expected serialization: the type byte followed by the payload, with no name.
    """
    # Test serialization
    output_bytes = nbt_class(value).serialize(with_name=False)
    output_bytes_no_type = nbt_class(value).serialize(with_type=False, with_name=False)
    assert output_bytes == expected_bytes
    # Without the type, the output is the expected data minus its first byte.
    assert output_bytes_no_type == expected_bytes[1:]

    # Writing into an existing buffer must produce the same bytes as serialize().
    buffer = Buffer()
    nbt_class(value).write_to(buffer, with_name=False)
    assert buffer == expected_bytes

    # Test deserialization
    # Through the generic NBTag entry point, which has to pick the tag class from the type byte.
    buffer = Buffer(expected_bytes)
    assert NBTag.deserialize(buffer, with_name=False) == nbt_class(value)

    # Through the concrete class, with the type byte stripped from the input.
    buffer = Buffer(expected_bytes[1:])
    assert nbt_class.deserialize(buffer, with_type=False, with_name=False) == nbt_class(value)

    # Same two variants through read_from.
    buffer = Buffer(expected_bytes)
    assert nbt_class.read_from(buffer, with_name=False) == nbt_class(value)

    buffer = Buffer(expected_bytes[1:])
    assert nbt_class.read_from(buffer, with_type=False, with_name=False) == nbt_class(value)
# Each case is (tag class, payload, tag name, expected bytes). The expected bytes are built as the type byte, then
# the name (a 2-byte length followed by its UTF-8 bytes), then the payload.
@pytest.mark.parametrize(
    ("nbt_class", "value", "name", "expected_bytes"),
    [
        (ByteNBT, 0, "test", bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("00")),
        (ByteNBT, 1, "a", bytearray.fromhex("01") + b"\x00\x01a" + bytearray.fromhex("01")),
        (ByteNBT, 127, "&à@é", bytearray.fromhex("01 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F")),
        (ByteNBT, -128, "test", bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("80")),
        (ByteNBT, 12, "a" * 100, bytearray.fromhex("01") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("0C")),
        (ShortNBT, 0, "test", bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("00 00")),
        (ShortNBT, 1, "a", bytearray.fromhex("02") + b"\x00\x01a" + bytearray.fromhex("00 01")),
        (ShortNBT, 32767, "&à@é", bytearray.fromhex("02 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF")),
        (ShortNBT, -32768, "test", bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("80 00")),
        (ShortNBT, 12, "a" * 100, bytearray.fromhex("02") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 0C")),
        (IntNBT, 0, "test", bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")),
        (IntNBT, 1, "a", bytearray.fromhex("03") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01")),
        (
            IntNBT,
            2147483647,
            "&à@é",
            bytearray.fromhex("03 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF FF FF"),
        ),
        (IntNBT, -2147483648, "test", bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00")),
        (
            IntNBT,
            12,
            "a" * 100,
            bytearray.fromhex("03") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 0C"),
        ),
        (LongNBT, 0, "test", bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00 00 00 00")),
        (LongNBT, 1, "a", bytearray.fromhex("04") + b"\x00\x01a" + bytearray.fromhex("00 00 00 00 00 00 00 01")),
        (
            LongNBT,
            (1 << 63) - 1,
            "&à@é",
            bytearray.fromhex("04 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF FF FF FF FF FF FF"),
        ),
        (
            LongNBT,
            -1 << 63,
            "test",
            bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00 00 00 00 00"),
        ),
        (
            LongNBT,
            12,
            "a" * 100,
            bytearray.fromhex("04") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 00 00 00 00 0C"),
        ),
        (FloatNBT, 1.0, "test", bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 1.0))),
        (FloatNBT, 3.14, "a", bytearray.fromhex("05") + b"\x00\x01a" + bytes(struct.pack(">f", 3.14))),
        (
            FloatNBT,
            -1.0,
            "&à@é",
            bytearray.fromhex("05 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">f", -1.0)),
        ),
        (FloatNBT, 12.0, "test", bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 12.0))),
        (DoubleNBT, 1.0, "test", bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 1.0))),
        (DoubleNBT, 3.14, "a", bytearray.fromhex("06") + b"\x00\x01a" + bytes(struct.pack(">d", 3.14))),
        (
            DoubleNBT,
            -1.0,
            "&à@é",
            bytearray.fromhex("06 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">d", -1.0)),
        ),
        (DoubleNBT, 12.0, "test", bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 12.0))),
        (ByteArrayNBT, b"", "test", bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")),
        (
            ByteArrayNBT,
            b"\x00",
            "a",
            bytearray.fromhex("07") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01") + b"\x00",
        ),
        (
            ByteArrayNBT,
            b"\x00\x01",
            "&à@é",
            bytearray.fromhex("07 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("00 00 00 02") + b"\x00\x01",
        ),
        (
            ByteArrayNBT,
            b"\x00\x01\x02",
            "test",
            bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 03") + b"\x00\x01\x02",
        ),
        (
            ByteArrayNBT,
            b"\xFF" * 1024,
            "a" * 100,
            bytearray.fromhex("07") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 04 00") + b"\xFF" * 1024,
        ),
        (StringNBT, "", "test", bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 00")),
        (StringNBT, "test", "a", bytearray.fromhex("08") + b"\x00\x01a" + bytearray.fromhex("00 04") + b"test"),
        (
            StringNBT,
            "a" * 100,
            "&à@é",
            bytearray.fromhex("08 00 06")
            + bytes("&à@é", "utf-8")
            + bytearray.fromhex("00 64")
            + b"a" * 100,
        ),
        (
            StringNBT,
            "&à@é",
            "test",
            bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 06") + bytes("&à@é", "utf-8"),
        ),
        (ListNBT, [], "test", bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00")),
        (
            ListNBT,
            [ByteNBT(-1)],
            "a",
            bytearray.fromhex("09") + b"\x00\x01a" + bytearray.fromhex("01 00 00 00 01 FF"),
        ),
        (
            ListNBT,
            [ShortNBT(127), ShortNBT(256)],
            "test",
            bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("02 00 00 00 02 00 7F 01 00"),
        ),
        (
            ListNBT,
            [ListNBT([ByteNBT(-1)]), ListNBT([IntNBT(256)])],
            "a",
            bytearray.fromhex("09")
            + b"\x00\x01a"
            + bytearray.fromhex("09 00 00 00 02 01 00 00 00 01 FF 03 00 00 00 01 00 00 01 00"),
        ),
        (CompoundNBT, [], "test", bytearray.fromhex("0A") + b"\x00\x04test" + bytearray.fromhex("00")),
        (
            CompoundNBT,
            [ByteNBT(0, name="Byte")],
            "test",
            bytearray.fromhex("0A") + b"\x00\x04test" + ByteNBT(0, name="Byte").serialize() + b"\x00",
        ),
        (
            CompoundNBT,
            [ShortNBT(128, "Short"), ByteNBT(-1, "Byte")],
            "test",
            bytearray.fromhex("0A")
            + b"\x00\x04test"
            + ShortNBT(128, "Short").serialize()
            + ByteNBT(-1, "Byte").serialize()
            + b"\x00",
        ),
        (
            CompoundNBT,
            [CompoundNBT([ByteNBT(0, name="Byte")], name="test")],
            "test",
            bytearray.fromhex("0A")
            + b"\x00\x04test"
            + CompoundNBT([ByteNBT(0, name="Byte")], "test").serialize()
            + b"\x00",
        ),
        (
            CompoundNBT,
            [ListNBT([ByteNBT(0)], name="List")],
            "test",
            bytearray.fromhex("0A") + b"\x00\x04test" + ListNBT([ByteNBT(0)], name="List").serialize() + b"\x00",
        ),
        (IntArrayNBT, [], "test", bytearray.fromhex("0B") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")),
        (
            IntArrayNBT,
            [0],
            "a",
            bytearray.fromhex("0B") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01") + b"\x00\x00\x00\x00",
        ),
        (
            IntArrayNBT,
            [0, 1],
            "&à@é",
            bytearray.fromhex("0B 00 06")
            + bytes("&à@é", "utf-8")
            + bytearray.fromhex("00 00 00 02")
            + b"\x00\x00\x00\x00\x00\x00\x00\x01",
        ),
        (
            IntArrayNBT,
            [1, 2, 3],
            "test",
            bytearray.fromhex("0B")
            + b"\x00\x04test"
            + bytearray.fromhex("00 00 00 03")
            + b"\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03",
        ),
        (
            IntArrayNBT,
            [(1 << 31) - 1],
            "a" * 100,
            bytearray.fromhex("0B")
            + b"\x00\x64"
            + b"a" * 100
            + bytearray.fromhex("00 00 00 01")
            + b"\x7F\xFF\xFF\xFF",
        ),
        (LongArrayNBT, [], "test", bytearray.fromhex("0C") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00")),
        (
            LongArrayNBT,
            [0],
            "a",
            bytearray.fromhex("0C")
            + b"\x00\x01a"
            + bytearray.fromhex("00 00 00 01")
            + b"\x00\x00\x00\x00\x00\x00\x00\x00",
        ),
        (
            LongArrayNBT,
            [0, 1],
            "&à@é",
            bytearray.fromhex("0C 00 06")
            + bytes("&à@é", "utf-8")
            + bytearray.fromhex("00 00 00 02")
            + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
        ),
        (
            LongArrayNBT,
            [1, 2, 3],
            "test",
            bytearray.fromhex("0C")
            + b"\x00\x04test"
            + bytearray.fromhex("00 00 00 03")
            + bytearray.fromhex("00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 03"),
        ),
        (
            LongArrayNBT,
            [(1 << 63) - 1] * 32768,
            "a" * 100,
            bytearray.fromhex("0C")
            + b"\x00\x64"
            + b"a" * 100
            + bytearray.fromhex("00 00 80 00")
            + b"\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF" * 32768,
        ),
    ],
)
def test_serialize_deserialize(nbt_class: type[NBTag], value: PayloadType, name: str, expected_bytes: bytes):
    """Test serialization/deserialization of NBT tag with name.

    :param nbt_class: The tag class under test.
    :param value: The payload used to build the tag.
    :param name: The name given to the tag.
    :param expected_bytes: The expected serialization, including the type byte and the name.
    """
    # Test serialization
    output_bytes = nbt_class(value, name).serialize()
    output_bytes_no_type = nbt_class(value, name).serialize(with_type=False)
    assert output_bytes == expected_bytes
    # Without the type, the output is the expected data minus its first byte.
    assert output_bytes_no_type == expected_bytes[1:]

    # Writing into an existing buffer must produce the same bytes as serialize().
    buffer = Buffer()
    nbt_class(value, name).write_to(buffer)
    assert buffer == expected_bytes

    # Test deserialization
    # Two copies back to back: each read must consume exactly one serialized tag and leave the rest untouched.
    buffer = Buffer(expected_bytes * 2)
    assert buffer.remaining == len(expected_bytes) * 2
    assert NBTag.deserialize(buffer) == nbt_class(value, name=name)
    assert buffer.remaining == len(expected_bytes)
    assert NBTag.deserialize(buffer) == nbt_class(value, name=name)
    assert buffer.remaining == 0

    # Through the concrete class, with the type byte stripped from the input.
    buffer = Buffer(expected_bytes[1:])
    assert nbt_class.deserialize(buffer, with_type=False) == nbt_class(value, name=name)

    # Same two variants through read_from.
    buffer = Buffer(expected_bytes)
    assert nbt_class.read_from(buffer) == nbt_class(value, name=name)

    buffer = Buffer(expected_bytes[1:])
    assert nbt_class.read_from(buffer, with_type=False) == nbt_class(value, name=name)
output_bytes == expected_bytes + assert output_bytes_no_type == expected_bytes[1:] + + buffer = Buffer() + nbt_class(value, name).write_to(buffer) + assert buffer == expected_bytes + + # Test deserialization + buffer = Buffer(expected_bytes * 2) + assert buffer.remaining == len(expected_bytes) * 2 + assert NBTag.deserialize(buffer) == nbt_class(value, name=name) + assert buffer.remaining == len(expected_bytes) + assert NBTag.deserialize(buffer) == nbt_class(value, name=name) + assert buffer.remaining == 0 + + buffer = Buffer(expected_bytes[1:]) + assert nbt_class.deserialize(buffer, with_type=False) == nbt_class(value, name=name) + + buffer = Buffer(expected_bytes) + assert nbt_class.read_from(buffer) == nbt_class(value, name=name) + + buffer = Buffer(expected_bytes[1:]) + assert nbt_class.read_from(buffer, with_type=False) == nbt_class(value, name=name) + + +@pytest.mark.parametrize( + ("nbt_class", "size", "tag"), + [ + (ByteNBT, 8, NBTagType.BYTE), + (ShortNBT, 16, NBTagType.SHORT), + (IntNBT, 32, NBTagType.INT), + (LongNBT, 64, NBTagType.LONG), + ], +) +def test_serialize_deserialize_numerical_fail(nbt_class: type[NBTag], size: int, tag: NBTagType): + """Test serialization/deserialization of NBT NUM tag with invalid value.""" + # Out of bounds + with pytest.raises(OverflowError): + nbt_class(1 << (size - 1)).serialize(with_name=False) + + with pytest.raises(OverflowError): + nbt_class(-(1 << (size - 1)) - 1).serialize(with_name=False) + + with pytest.raises(ValueError): # No name + nbt_class(0, "").serialize() # without with_name=False + + # Deserialization + buffer = Buffer(bytearray([tag.value + 1] + [0] * (size // 8))) + with pytest.raises(TypeError): # Tries to read a nbt_class, but it's one higher + nbt_class.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([tag.value] + [0] * ((size // 8) - 1))) + with pytest.raises(IOError): + nbt_class.read_from(buffer, with_name=False) + + buffer = Buffer(bytearray([tag.value, 0, 
0] + [0] * (size // 8))) + assert nbt_class.read_from(buffer, with_name=True) == nbt_class(0) + + +# endregion + +# region FloatNBT + + +def test_serialize_deserialize_float_fail(): + """Test serialization/deserialization of NBT FLOAT tag with invalid value.""" + with pytest.raises(ValueError): + FloatNBT(0, 0).serialize() # type:ignore + + with pytest.raises(struct.error): + FloatNBT("test").serialize(with_name=False) + + with pytest.raises(OverflowError): + FloatNBT(1e39, "test").serialize() + + with pytest.raises(OverflowError): + FloatNBT(-1e39, "test").serialize() + + # Deserialization + buffer = Buffer(bytearray([NBTagType.BYTE] + [0] * 4)) + with pytest.raises(TypeError): # Tries to read a FloatNBT, but it's a ByteNBT + FloatNBT.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.FLOAT, 0, 0, 0])) + with pytest.raises(IOError): + FloatNBT.read_from(buffer, with_name=False) + + +# endregion +# region DoubleNBT + + +def test_serialize_deserialize_double_fail(): + """Test serialization/deserialization of NBT DOUBLE tag with invalid value.""" + with pytest.raises(ValueError): + DoubleNBT(0, 0).serialize() # type: ignore + + with pytest.raises(struct.error): + DoubleNBT("test").serialize(with_name=False) + + # Deserialization + buffer = Buffer(bytearray([0x01] + [0] * 8)) + with pytest.raises(TypeError): # Tries to read a DoubleNBT, but it's a ByteNBT + DoubleNBT.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.DOUBLE, 0, 0, 0, 0, 0, 0, 0])) + with pytest.raises(IOError): + DoubleNBT.read_from(buffer, with_name=False) + + +# endregion +# region ByteArrayNBT + + +def test_serialize_deserialize_bytearray_fail(): + """Test serialization/deserialization of NBT BYTEARRAY tag with invalid value.""" + with pytest.raises(ValueError): + ByteArrayNBT([], 0).serialize() # type:ignore + + with pytest.raises(ValueError): + ByteArrayNBT(b"test", "").serialize() + + # Deserialization 
+ buffer = Buffer(bytearray([0x01] + [0] * 4)) + with pytest.raises(TypeError): # Tries to read a ByteArrayNBT, but it's a ByteNBT + ByteArrayNBT.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0])) # Missing length bytes + with pytest.raises(IOError): + ByteArrayNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0, 1])) # Missing data bytes + with pytest.raises(IOError): + ByteArrayNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0, 2, 0])) # Missing data bytes + with pytest.raises(IOError): + ByteArrayNBT.read_from(buffer, with_name=False) + + # Negative length + buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0xFF, 0xFF, 0xFF, 0xFF])) # length = -1 + with pytest.raises(ValueError): + ByteArrayNBT.deserialize(buffer, with_name=False) + + +# endregion +# region StringNBT + + +def test_serialize_deserialize_string_fail(): + """Test serialization/deserialization of NBT STRING tag with invalid value.""" + with pytest.raises(ValueError): + StringNBT("", 0).serialize() # type:ignore + + with pytest.raises(ValueError): + StringNBT("test", "").serialize() + + # Deserialization + buffer = Buffer(bytearray([0x01, 0, 0])) + with pytest.raises(TypeError): # Tries to read a StringNBT, but it's a ByteNBT + StringNBT.deserialize(buffer, with_name=False) + + # Not enough data for the length + buffer = Buffer(bytearray([NBTagType.STRING, 0])) + with pytest.raises(IOError): + StringNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.STRING, 0, 1])) + with pytest.raises(IOError): + StringNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.STRING, 0, 2, 0])) + with pytest.raises(IOError): + StringNBT.read_from(buffer, with_name=False) + + # Negative length + buffer = 
Buffer(bytearray([NBTagType.STRING, 0xFF, 0xFF])) # length = -1 + with pytest.raises(ValueError): + StringNBT.deserialize(buffer, with_name=False) + + # Invalid UTF-8 + buffer = Buffer(bytearray([NBTagType.STRING, 0, 1, 0xC0, 0x80])) + with pytest.raises(UnicodeDecodeError): + StringNBT.read_from(buffer, with_name=False) + + +# endregion +# region ListNBT + + +@pytest.mark.parametrize( + ("payload", "error"), + [ + ([ByteNBT(0), IntNBT(0)], ValueError), + ([ByteNBT(0), "test"], ValueError), + ([ByteNBT(0), None], ValueError), + ([ByteNBT(0), ByteNBT(-1, "Hello World")], ValueError), # All unnamed tags + ([ByteNBT(128), ByteNBT(-1)], OverflowError), # Check for error propagation + ], +) +def test_serialize_list_fail(payload, error): + """Test serialization of NBT LIST tag with invalid value.""" + with pytest.raises(error): + ListNBT(payload, "test").serialize() + + +def test_deserialize_list_fail(): + """Test deserialization of NBT LIST tag with invalid value.""" + # Wrong tag type + buffer = Buffer(bytearray([0x09, 255, 0, 0, 0, 1, 0])) + with pytest.raises(TypeError): + ListNBT.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([0x09, 1, 0, 0, 0, 1])) + with pytest.raises(IOError): + ListNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([0x09, 1, 0, 0, 0])) + with pytest.raises(IOError): + ListNBT.read_from(buffer, with_name=False) + + +# endregion +# region CompoundNBT + + +@pytest.mark.parametrize( + ("payload", "error"), + [ + ([ByteNBT(0, name="Hello"), IntNBT(0)], ValueError), + ([ByteNBT(0, name="hi"), "test"], ValueError), + ([ByteNBT(0, name="hi"), None], ValueError), + ([ByteNBT(0), ByteNBT(-1, "Hello World")], ValueError), # All unnamed tags + ([ByteNBT(128, name="Jello"), ByteNBT(-1, name="Bonjour")], OverflowError), # Check for error propagation + ], +) +def test_serialize_compound_fail(payload, error): + """Test serialization of NBT COMPOUND tag with invalid value.""" + with 
pytest.raises(error): + CompoundNBT(payload, "test").serialize() + + # Double name + with pytest.raises(ValueError): + CompoundNBT([ByteNBT(0, name="test"), ByteNBT(0, name="test")], "comp").serialize() + + +def test_deseialize_compound_fail(): + """Test deserialization of NBT COMPOUND tag with invalid value.""" + # Not enough data + buffer = Buffer(bytearray([NBTagType.COMPOUND, 0x01])) + with pytest.raises(IOError): + CompoundNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.COMPOUND])) + with pytest.raises(IOError): + CompoundNBT.read_from(buffer, with_name=False) + + # Wrong tag type + buffer = Buffer(bytearray([15])) + with pytest.raises(TypeError): + NBTag.deserialize(buffer) + + +def test_to_object_compound(): + """Try a few incorrect CompoundNBT.to_object() calls.""" + comp = CompoundNBT([ByteNBT(0, "test"), ByteNBT(1, "test")]) + with pytest.raises(ValueError): + comp.to_object() # Duplicate name + + comp = CompoundNBT([ByteNBT(0), ByteNBT(1)]) + with pytest.raises(ValueError): + comp.to_object() + + +def test_equality_compound(): + """Test equality of CompoundNBT.""" + comp1 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp") + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp") + assert comp1 == comp2 + + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2")], "comp") + assert comp1 != comp2 + + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test4")], "comp") + assert comp1 != comp2 + + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], "comp2") + assert comp1 != comp2 + + assert comp1 != ByteNBT(0, name="comp") + + +# endregion +# region IntArrayNBT + + +@pytest.mark.parametrize( + ("payload", "error"), + [ + ([0, "test"], ValueError), + ([0, None], ValueError), + ([0, 1 << 31], OverflowError), + 
([0, -(1 << 31) - 1], OverflowError), + ], +) +def test_serialize_intarray_fail(payload, error): + """Test serialization of NBT INTARRAY tag with invalid value.""" + with pytest.raises(error): + IntArrayNBT(payload, "test").serialize() + + +def test_deserialize_intarray_fail(): + """Test deserialization of NBT INTARRAY tag with invalid value.""" + # Not enough data for 1 element + buffer = Buffer(bytearray([0x0B, 0, 0, 0, 1, 0, 0, 0])) + with pytest.raises(IOError): + IntArrayNBT.deserialize(buffer, with_name=False) + + # Not enough data for the size + buffer = Buffer(bytearray([0x0B, 0, 0, 0])) + with pytest.raises(IOError): + IntArrayNBT.read_from(buffer, with_name=False) + + # Not enough data to start the 2nd element + buffer = Buffer(bytearray([0x0B, 0, 0, 0, 2, 1, 0, 0, 0])) + with pytest.raises(IOError): + IntArrayNBT.read_from(buffer, with_name=False) + + +# endregion +# region LongArrayNBT + + +@pytest.mark.parametrize( + ("payload", "error"), + [ + ([0, "test"], ValueError), + ([0, None], ValueError), + ([0, 1 << 63], OverflowError), + ([0, -(1 << 63) - 1], OverflowError), + ], +) +def test_serialize_deserialize_longarray_fail(payload, error): + """Test serialization/deserialization of NBT LONGARRAY tag with invalid value.""" + with pytest.raises(error): + LongArrayNBT(payload, "test").serialize() + + +def test_deserialize_longarray_fail(): + """Test deserialization of NBT LONGARRAY tag with invalid value.""" + # Not enough data for 1 element + buffer = Buffer(bytearray([0x0C, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0])) + with pytest.raises(IOError): + LongArrayNBT.deserialize(buffer, with_name=False) + + # Not enough data for the size + buffer = Buffer(bytearray([0x0C, 0, 0, 0])) + with pytest.raises(IOError): + LongArrayNBT.read_from(buffer, with_name=False) + + # Not enough data to start the 2nd element + buffer = Buffer(bytearray([0x0C, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0])) + with pytest.raises(IOError): + LongArrayNBT.read_from(buffer, with_name=False) + + +# 
endregion + +# region NBTag + + +def test_nbt_helloworld(): + """Test serialization/deserialization of a simple NBT tag. + + Source data: https://wiki.vg/NBT#Example. + """ + data = bytearray.fromhex("0a000b68656c6c6f20776f726c640800046e616d65000942616e616e72616d6100") + buffer = Buffer(data) + + expected_object = { + "hello world": { + "name": "Bananrama", + } + } + + data = CompoundNBT.deserialize(buffer) + assert data == NBTag.from_object(expected_object) + assert data.to_object() == expected_object + + +def test_nbt_bigfile(): + """Test serialization/deserialization of a big NBT tag. + + Slighly modified from the source data to also include a IntArrayNBT and a LongArrayNBT. + Source data: https://wiki.vg/NBT#Example. + """ + data = "0a00054c6576656c0400086c6f6e67546573747fffffffffffffff02000973686f7274546573747fff08000a737472696e6754657374002948454c4c4f20574f524c4420544849532049532041205445535420535452494e4720c385c384c39621050009666c6f6174546573743eff1832030007696e74546573747fffffff0a00146e657374656420636f6d706f756e6420746573740a000368616d0800046e616d65000648616d70757305000576616c75653f400000000a00036567670800046e616d6500074567676265727405000576616c75653f00000000000c000f6c6973745465737420286c6f6e672900000005000000000000000b000000000000000c000000000000000d000000000000000e7fffffffffffffff0b000e6c697374546573742028696e7429000000047fffffff7ffffffe7ffffffd7ffffffc0900136c697374546573742028636f6d706f756e64290a000000020800046e616d65000f436f6d706f756e642074616720233004000a637265617465642d6f6e000001265237d58d000800046e616d65000f436f6d706f756e642074616720233104000a637265617465642d6f6e000001265237d58d0001000862797465546573747f07006562797465417272617954657374202874686520666972737420313030302076616c756573206f6620286e2a6e2a3235352b6e2a3729253130302c207374617274696e672077697468206e3d302028302c2036322c2033342c2031362c20382c202e2e2e2929000003e8003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a406
0265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a063005000a646f75626c65546573743efc7b5e00" # noqa: E501 + data = bytearray.fromhex(data) + buffer = Buffer(data) + + expected_object = { + 
"Level": { + "longTest": 9223372036854775807, + "shortTest": 32767, + "stringTest": "HELLO WORLD THIS IS A TEST STRING ÅÄÖ!", + "floatTest": 0.4982314705848694, + "intTest": 2147483647, + "nested compound test": { + "ham": {"name": "Hampus", "value": 0.75}, + "egg": {"name": "Eggbert", "value": 0.5}, + }, + "listTest (long)": [11, 12, 13, 14, 9223372036854775807], + "listTest (int)": [2147483647, 2147483646, 2147483645, 2147483644], + "listTest (compound)": [ + {"name": "Compound tag #0", "created-on": 1264099775885}, + {"name": "Compound tag #1", "created-on": 1264099775885}, + ], + "byteTest": 127, + "byteArrayTest (the first 1000 values of (n*n*255+n*7)%100" + ", starting with n=0 (0, 62, 34, 16, 8, ...))": bytearray( + (n * n * 255 + n * 7) % 100 for n in range(1000) + ), + "doubleTest": 0.4931287132182315, + } + } + + data = CompoundNBT.deserialize(buffer) + # print(f"{data=}\n{expected_object=}\n{data.to_object()=}\n{NBTag.from_object(expected_object)=}") + + def check_equality(self, other): + """Check if two objects are equal, with deep epsilon check for floats.""" + if type(self) != type(other): + return False + if isinstance(self, dict): + if len(self) != len(other): + return False + for key in self: + if key not in other: + return False + if not check_equality(self[key], other[key]): + return False + return True + if isinstance(self, list): + if len(self) != len(other): + return False + return all(check_equality(self[i], other[i]) for i in range(len(self))) + if isinstance(self, float): + return abs(self - other) < 1e-6 + if self != other: + return False + return self == other + + assert data == NBTag.from_object(expected_object) + assert check_equality(data.to_object(), expected_object) + + +# endregion +# region Edge cases + + +def test_from_object_lst_not_same_type(): + """Test from_object with a list that does not have the same type.""" + with pytest.raises(TypeError): + NBTag.from_object([ByteNBT(0), IntNBT(0)]) + + +def 
test_from_object_out_of_bounds(): + """Test from_object with a value that is out of bounds.""" + with pytest.raises(ValueError): + NBTag.from_object({"test": 1 << 63}) + + with pytest.raises(ValueError): + NBTag.from_object({"test": -(1 << 63) - 1}) + + with pytest.raises(ValueError): + NBTag.from_object({"test": [1 << 63]}) + + with pytest.raises(ValueError): + NBTag.from_object({"test": [-(1 << 63) - 1]}) + + +def test_from_object_morecases(): + """Test from_object with more edge cases.""" + + class CustomType: + def __bytes__(self): + return b"test" + + assert NBTag.from_object( + { + "nbtag": ByteNBT(0), # ByteNBT + "bytearray": b"test", # Conversion from bytes + "empty_list": [], # Empty list with type EndNBT + "empty_compound": {}, # Empty compound + "end_NBTag": None, # Should not be done in practice, would create a broken buffer if serialized + "custom": CustomType(), # Custom type with __bytes__ method + } + ) == CompoundNBT( + [ # Order is shuffled because the spec does not require a specific order + CompoundNBT([], "empty_compound"), + ByteArrayNBT(b"test", "bytearray"), + ByteArrayNBT(b"test", "custom"), + ListNBT([], "empty_list"), + ByteNBT(0, "nbtag"), + EndNBT(), + ] + ) + + # Not a valid object + with pytest.raises(TypeError): + NBTag.from_object({"test": object()}) + + compound = CompoundNBT.from_object( + { + "test": ByteNBT(0), + "test2": IntNBT(0), + }, + name="compound", + ) + assert compound["test"] == ByteNBT(0, "test") + assert compound["test2"] == IntNBT(0, "test2") + with pytest.raises(KeyError): + compound["test3"] + + # Cannot index into a ByteNBT + with pytest.raises(TypeError): + compound["test"][0] # type:ignore + + listnbt = ListNBT.from_object([0, 1, 2], use_int_array=False) + assert listnbt[0] == ByteNBT(0) + assert listnbt[1] == ByteNBT(1) + assert listnbt[2] == ByteNBT(2) + with pytest.raises(IndexError): + listnbt[3] + with pytest.raises(TypeError): + listnbt["hello"] + + assert listnbt[-1] == ByteNBT(2) + assert listnbt[-2] == 
ByteNBT(1) + assert listnbt[-3] == ByteNBT(0) + + with pytest.raises(TypeError): + listnbt[object()] # type:ignore + + assert listnbt.value == [0, 1, 2] + assert listnbt.to_object() == [0, 1, 2] + assert ListNBT([]).value == [] + assert compound.to_object() == {"compound": {"test": 0, "test2": 0}} + assert compound.value == {"test": 0, "test2": 0} + assert ListNBT([IntNBT(0)]).value == [0] + + assert ByteNBT(12).value == 12 + assert ShortNBT(13).value == 13 + assert IntNBT(14).value == 14 + assert LongNBT(15).value == 15 + assert FloatNBT(0.5).value == 0.5 + assert DoubleNBT(0.6).value == 0.6 + assert ByteArrayNBT(b"test").value == b"test" + assert StringNBT("test").value == "test" + assert IntArrayNBT([0, 1, 2]).value == [0, 1, 2] + assert LongArrayNBT([0, 1, 2, 3]).value == [0, 1, 2, 3] + + invalid = ListNBT("Hello", "name") + with pytest.raises(AttributeError): + invalid[0] + + invalid = CompoundNBT([ByteNBT(0, "Byte"), "Hi"], "name") + with pytest.raises(AttributeError): + invalid["Byte"] # Attribute error is raised when the structure is incorrectly constructed + + +def test_to_object_morecases(): + """Test to_object with more edge cases.""" + + class CustomType: + def __bytes__(self): + return b"test" + + assert NBTag.from_object( + { + "bytearray": b"test", + "empty_list": [], + "empty_compound": {}, + "custom": CustomType(), + } + ).to_object() == { + "bytearray": b"test", + "empty_list": [], + "empty_compound": {}, + "custom": b"test", + } + + assert NBTag.to_object(CompoundNBT([])) == {} + + assert EndNBT().to_object() == {} # Does not add anything when doing dict.update + assert FloatNBT(0.5).to_object() == 0.5 + assert FloatNBT(0.5, "Hello World").to_object() == {"Hello World": 0.5} + assert ByteArrayNBT(b"test").to_object() == b"test" # Do not add name when there is no name + assert StringNBT("test").to_object() == "test" + assert StringNBT("test", "name").to_object() == {"name": "test"} + assert ListNBT([ByteNBT(0), ByteNBT(1)]).to_object() == [0, 1] + 
assert ListNBT([ByteNBT(0), ByteNBT(1)], "name").to_object() == {"name": [0, 1]} + assert IntArrayNBT([0, 1, 2]).to_object() == [0, 1, 2] + assert LongArrayNBT([0, 1, 2]).to_object() == [0, 1, 2] + + +def test_data_conversions(): + """Test data conversions using the built-in functions.""" + assert int(IntNBT(-1)) == -1 + assert float(FloatNBT(0.5)) == 0.5 + assert str(StringNBT("test")) == "test" + assert bytes(ByteArrayNBT(b"test")) == b"test" + assert list(ListNBT([ByteNBT(0), ByteNBT(1)])) == [ByteNBT(0), ByteNBT(1)] + assert dict(CompoundNBT([ByteNBT(0, "first"), ByteNBT(1, "second")])) == { + "first": ByteNBT(0, "first"), + "second": ByteNBT(1, "second"), + } + assert list(IntArrayNBT([0, 1, 2])) == [0, 1, 2] + assert list(LongArrayNBT([0, 1, 2])) == [0, 1, 2] + + +def test_init_nbtag_directly(): + """Test initializing NBTag directly.""" + with pytest.raises(TypeError): + NBTag(0) + with pytest.raises(TypeError): + NBTag(0, "test") + with pytest.raises(TypeError): + NBTag(0, name="test") + + +# endregion