diff --git a/changes/257.feature.md b/changes/257.feature.md new file mode 100644 index 00000000..453b22bf --- /dev/null +++ b/changes/257.feature.md @@ -0,0 +1,12 @@ +- Added the `NBTag` to deal with NBT data: + - The `NBTag` class is the base class for all NBT tags and provides the basic functionality to serialize and deserialize NBT data from and to a `Buffer` object. + - The classes `EndNBT`, `ByteNBT`, `ShortNBT`, `IntNBT`, `LongNBT`, `FloatNBT`, `DoubleNBT`, `ByteArrayNBT`, `StringNBT`, `ListNBT`, `CompoundNBT`, `IntArrayNBT`and `LongArrayNBT` were added and correspond to the NBT types described in the [NBT specification](https://wiki.vg/NBT#Specification). + - NBT tags can be created using the `NBTag.from_object()` method and a schema that describes the NBT tag structure. + Compound tags are represented as dictionaries, list tags as lists, and primitive tags as their respective Python types. + The implementation allows to add custom classes to the schema to handle custom NBT tags if they inherit the `:class: NBTagConvertible` class. + - The `NBTag.to_object()` method can be used to convert an NBT tag back to a Python object. Use include_schema=True to include the schema in the output, and `include_name=True` to include the name of the tag in the output. In that case the output will be a dictionary with a single key that is the name of the tag and the value is the object representation of the tag. + - The `NBTag.serialize()` can be used to serialize an NBT tag to a new `Buffer` object. + - The `NBTag.deserialize(buffer)` can be used to deserialize an NBT tag from a `Buffer` object. + - If the buffer already exists, the `NBTag.write_to(buffer, with_type=True, with_name=True)` method can be used to write the NBT tag to the buffer (and in that case with the type and name in the right format). + - The `NBTag.read_from(buffer, with_type=True, with_name=True)` method can be used to read an NBT tag from the buffer (and in that case with the type and name in the right format). + - The `NBTag.value` property can be used to get the value of the NBT tag as a Python object. diff --git a/mcproto/types/nbt.py b/mcproto/types/nbt.py new file mode 100644 index 00000000..cc529a86 --- /dev/null +++ b/mcproto/types/nbt.py @@ -0,0 +1,1147 @@ +from __future__ import annotations + +from abc import abstractmethod +from enum import IntEnum +from typing import Iterator, List, Mapping, Sequence, Tuple, Type, Union, cast + +from typing_extensions import TypeAlias, override +from typing import Protocol, runtime_checkable # Have to be imported from the same place + +from mcproto.buffer import Buffer +from mcproto.protocol.base_io import StructFormat +from mcproto.types.abc import MCType + +__all__ = [ + "NBTagConvertible", + "NBTagType", + "NBTag", + "EndNBT", + "ByteNBT", + "ShortNBT", + "IntNBT", + "LongNBT", + "FloatNBT", + "DoubleNBT", + "ByteArrayNBT", + "StringNBT", + "ListNBT", + "CompoundNBT", + "IntArrayNBT", + "LongArrayNBT", +] + +""" +Implementation of the NBT (Named Binary Tag) format used in Minecraft as described in the NBT specification + +Source : `Minecraft NBT Spec `_ + +Named Binary Tag specification + +NBT (Named Binary Tag) is a tag based binary format designed to carry large amounts of binary data with smaller +amounts of additional data. +An NBT file consists of a single GZIPped Named Tag of type TAG_Compound. + +A Named Tag has the following format: + + byte tagType + TAG_String name + [payload] + +The tagType is a single byte defining the contents of the payload of the tag. + +The name is a descriptive name, and can be anything (eg "cat", "banana", "Hello World!"). It has nothing to do with +the tagType. +The purpose for this name is to name tags so parsing is easier and can be made to only look for certain recognized +tag names. +Exception: If tagType is TAG_End, the name is skipped and assumed to be "". + +The [payload] varies by tagType. + +Note that ONLY Named Tags carry the name and tagType data. Explicitly identified Tags (such as TAG_String above) +only contains the payload. + +.. seealso:: :class:`NBTagType` + +""" +# region NBT Specification + + +class NBTagType(IntEnum): + """Enumeration of the different types of NBT tags. + + TYPE: 0 NAME: TAG_End + Payload: None. + Note: This tag is used to mark the end of a list. + Cannot be named! If type 0 appears where a Named Tag is expected, the name is assumed to be "". + (In other words, this Tag is always just a single 0 byte when named, and nothing in all other cases) + + TYPE: 1 NAME: TAG_Byte + Payload: A single signed byte (8 bits) + + TYPE: 2 NAME: TAG_Short + Payload: A signed short (16 bits, big endian) + + TYPE: 3 NAME: TAG_Int + Payload: A signed short (32 bits, big endian) + + TYPE: 4 NAME: TAG_Long + Payload: A signed long (64 bits, big endian) + + TYPE: 5 NAME: TAG_Float + Payload: A floating point value (32 bits, big endian, IEEE 754-2008, binary32) + + TYPE: 6 NAME: TAG_Double + Payload: A floating point value (64 bits, big endian, IEEE 754-2008, binary64) + + TYPE: 7 NAME: TAG_Byte_Array + Payload: TAG_Int length + An array of bytes of unspecified format. The length of this array is bytes + + TYPE: 8 NAME: TAG_String + Payload: TAG_Short length + An array of bytes defining a string in UTF-8 format. The length of this array is bytes + + TYPE: 9 NAME: TAG_List + Payload: TAG_Byte tagId + TAG_Int length + A sequential list of Tags (not Named Tags), of type . The length of this array is + Tags + Notes: All tags share the same type. + + TYPE: 10 NAME: TAG_Compound + Payload: A sequential list of Named Tags. This array keeps going until a TAG_End is found. + TAG_End end + Notes: If there's a nested TAG_Compound within this tag, that one will also have a TAG_End, so simply reading + until the next TAG_End will not work. + The names of the named tags have to be unique within each TAG_Compound + The order of the tags is not guaranteed. + + + // NEW TAGS (not in the original spec) + TYPE: 11 NAME: TAG_Int_Array + Payload: TAG_Int length + An array of integers. The length of this array is integers + + TYPE: 12 NAME: TAG_Long_Array + Payload: TAG_Int length + An array of longs. The length of this array is longs + + + """ + + END = 0 + BYTE = 1 + SHORT = 2 + INT = 3 + LONG = 4 + FLOAT = 5 + DOUBLE = 6 + BYTE_ARRAY = 7 + STRING = 8 + LIST = 9 + COMPOUND = 10 + INT_ARRAY = 11 + LONG_ARRAY = 12 + + +PayloadType: TypeAlias = Union[ + int, + float, + bytes, + str, + "NBTag", + Sequence["PayloadType"], + Mapping[str, "PayloadType"], +] + + +@runtime_checkable +class NBTagConvertible(Protocol): + """Protocol for objects that can be converted to an NBT tag.""" + + __slots__ = () + + def to_nbt(self, name: str = "") -> NBTag: + """Convert the object to an NBT tag. + + :param name: The name of the tag. + + :return: The NBT tag created from the object. + """ + ... + + +FromObjectType: TypeAlias = Union[ + int, + float, + bytes, + str, + NBTagConvertible, + Sequence["FromObjectType"], + Mapping[str, "FromObjectType"], +] + +FromObjectSchema: TypeAlias = Union[ + Type["NBTag"], + Type[NBTagConvertible], + Sequence["FromObjectSchema"], + Mapping[str, "FromObjectSchema"], +] + + +class NBTag(MCType, NBTagConvertible): + """Base class for NBT tags. + + In MC v1.20.2+ the type and name of the root tag are not written to the buffer, and unless specified, the type of + the tag is assumed to be TAG_Compound. + """ + + __slots__ = ("name", "payload") + + def __init__(self, payload: PayloadType, name: str = ""): + self.name = name + self.payload = payload + + @override + def serialize(self, with_type: bool = True, with_name: bool = True) -> Buffer: + """Serialize the NBT tag to a buffer. + + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + .. note:: These parameters only control the first level of serialization. + + :return: The buffer containing the serialized NBT tag. + """ + buf = Buffer() + self.write_to(buf, with_name=with_name, with_type=with_type) + return buf + + def _write_header(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the header of the NBT tag to the buffer. + + :param buf: The buffer to write to. + :param with_type: Whether to include the type of the tag in the serialization. + :param with_name: Whether to include the name of the tag in the serialization. + """ + if with_type: + tag_type = _get_tag_type(self) + buf.write_value(StructFormat.BYTE, tag_type.value) + if with_name and self.name: + StringNBT(self.name).write_to(buf, with_type=False, with_name=False) + + @abstractmethod + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + """Write the NBT tag to the buffer.""" + ... + + @override + @classmethod + def deserialize(cls, buf: Buffer, with_name: bool = True, with_type: bool = True) -> NBTag: + """Deserialize the NBT tag. + + :param buf: The buffer to read from. + :param with_name: Whether to read the name of the tag. If set to False, the tag will have the name "". + :param with_type: Whether to read the type of the tag. + If set to True, the tag will be read from the buffer and + the return type will be inferred from the tag type. + If set to False and called from a subclass, the tag type will be inferred from the subclass. + If set to False and called from the base class, the tag type will be TAG_Compound. + + If with_type is set to False, the buffer must not start with the tag type byte. + + :return: The deserialized NBT tag. + """ + name, tag_type = cls._read_header(buf, with_name=with_name, read_type=with_type) + + tag_class = ASSOCIATED_TYPES[tag_type] + if cls not in (NBTag, tag_class): + raise TypeError(f"Expected a {cls.__name__} tag, but found a different tag ({tag_class.__name__}).") + + tag = tag_class.read_from(buf, with_type=False, with_name=False) + tag.name = name + return tag + + @classmethod + def _read_header(cls, buf: Buffer, read_type: bool = True, with_name: bool = True) -> tuple[str, NBTagType]: + """Read the header of the NBT tag. + + :param buf: The buffer to read from. + :param read_type: Whether to read the type of the tag from the buffer. + :param with_name: Whether to read the name of the tag. If set to False, the tag will have the name "". + + :return: A tuple containing the name and the tag type. + + + .. note:: It is possible that this function reads nothing from the buffer if both with_name and read_type are + set to False. + """ + if read_type: + try: + tag_type = NBTagType(buf.read_value(StructFormat.BYTE)) + except OSError as exc: + raise IOError("Buffer is empty.") from exc + except ValueError as exc: + raise TypeError("Invalid tag type.") from exc + else: + tag_type = _get_tag_type(cls) + + if tag_type is NBTagType.END: + return "", tag_type + + name = StringNBT.read_from(buf, with_type=False, with_name=False).value if with_name else "" + + return name, tag_type + + @classmethod + @abstractmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> NBTag: + """Read the NBT tag from the buffer. + + :param buf: The buffer to read from. + :param with_type: Whether to read the type of the tag. + If set to True, the tag will be read from the buffer and + the return type will be inferred from the tag type. + If set to False and called from a subclass, the tag type will be inferred from the subclass. + If set to False and called from the base class, the tag type will be TAG_Compound. + :param with_name: Whether to read the name of the tag. If set to False, the tag will have the name "". + + :return: The NBT tag. + """ + ... + + @staticmethod + def from_object(data: FromObjectType, schema: FromObjectSchema, name: str = "") -> NBTag: + """Create an NBT tag from a python object and a schema. + + :param data: The python object to create the NBT tag from. + :param schema: The schema used to create the NBT tags. + :param name: The name of the NBT tag. + + The schema is a description of the types of the data in the python object. + The schema can be a subclass of NBTag (e.g. IntNBT, StringNBT, CompoundNBT, etc.), a dictionary, a list, a + tuple, or an object that has a `to_nbt` method. + + Example of schema: + schema = { + "string": StringNBT, + "list_of_floats": [FloatNBT], + "list_of_compounds": [{ + "key": StringNBT, + "value": IntNBT, + }], + "list_of_lists": [[IntNBT], [StringNBT]], + } + + This would be translated into a CompoundNBT + + :return: The NBT tag created from the python object. + """ + # Case 0 : schema is an object with a `to_nbt` method (could be a subclass of NBTag for all we know, as long + # as the data is an instance of the schema it will work) + if isinstance(schema, type) and hasattr(schema, "to_nbt") and isinstance(data, schema): + return data.to_nbt(name=name) + + # Case 1 : schema is a NBTag subclass + if isinstance(schema, type) and issubclass(schema, NBTag): + if schema in (CompoundNBT, ListNBT): + raise ValueError("Use a list or a dictionary in the schema to create a CompoundNBT or a ListNBT.") + # Check if the data contains the name (if it is a dictionary) + if isinstance(data, dict): + if len(data) != 1: + raise ValueError("Expected a dictionary with a single key-value pair.") + # We also check if the name isn't already set + if name: + raise ValueError("The name is already set.") + key, value = next(iter(data.items())) + # Recursive call to go to the next part + return NBTag.from_object(value, schema, name=key) + # Else we check if the data can be a payload for the tag + if not isinstance(data, (bytes, str, int, float, list)): + raise TypeError(f"Expected a bytes, str, int, float, but found {type(data).__name__}.") + # Check if the data is a list of integers + if isinstance(data, list) and not all(isinstance(item, int) for item in data): + raise TypeError("Expected a list of integers.") + data = cast(Union[bytes, str, int, float, List[int]], data) + # Create the tag with the data and the name + return schema(data, name=name) + + # Sanity check : Verify that all type schemas have been handled + if not isinstance(schema, (list, tuple, dict)): + raise TypeError( + "The schema must be a list, dict, a subclass of NBTag or an object with a `to_nbt` method." + ) + + # Case 2 : schema is a dictionary + if isinstance(schema, dict): + # We can unpack the dictionary and create a CompoundNBT tag + if not isinstance(data, dict): + raise TypeError("Expected a dictionary, but found a different type.") + # Iterate over the dictionary + payload: list[NBTag] = [] + for key, value in data.items(): + # Recursive calls + payload.append(NBTag.from_object(value, schema[key], name=key)) + # Finally we assign the payload and the name to the CompoundNBT tag + return CompoundNBT(payload, name=name) + + # Case 3 : schema is a list or a tuple + # We need to check if every element in the schema has the same type + # but keep in mind that dict and list are also valid types, as long + # as there are only dicts, or only lists in the schema + if not isinstance(data, list): + raise TypeError("Expected a list, but found a different type.") + payload: list[NBTag] = [] + if len(schema) == 1: + # We have two cases here, either the schema supports an unknown number of elements of a single type ... + children_schema = schema[0] + for item in data: + # No name in list items + payload.append(NBTag.from_object(item, children_schema)) + return ListNBT(payload, name=name) + + # ... or the schema is a list of schemas + # Check if the schema and the data have the same length + if len(schema) != len(data): + raise ValueError(f"The schema and the data must have the same length. ({len(schema)=} != {len(data)=})") + if len(schema) == 0: + return ListNBT([], name=name) + # Check that the schema only has one type of elements + first_schema = schema[0] + # Dict/List case + if isinstance(first_schema, (list, dict)) and not all(isinstance(item, type(first_schema)) for item in schema): + raise TypeError(f"Expected a list of lists or dictionaries, but found a different type ({schema=}).") + # NBTag case + if isinstance(first_schema, type) and not all(item == first_schema for item in schema): + raise TypeError(f"The schema must contain a single type of elements. ({schema=})") + + for item, sub_schema in zip(data, schema): + payload.append(NBTag.from_object(item, sub_schema)) + return ListNBT(payload, name=name) + + def to_object( + self, include_schema: bool = False, include_name: bool = False + ) -> PayloadType | Mapping[str, PayloadType] | tuple[PayloadType | Mapping[str, PayloadType], FromObjectSchema]: + """Convert the NBT tag to a python object. + + :param include_schema: Whether to return a schema describing the types of the original tag. + :param include_name: Whether to include the name of the tag in the output. + If the tag has no name, the name will be set to "". + + :return: Either : + - A python object representing the payload of the tag. (default) + - A dictionary containing the name associated with a python object representing the payload of the tag. + - A tuple which includes one of the above and a schema describing the types of the original tag. + """ + if type(self) is EndNBT: + return NotImplemented + if type(self) in (CompoundNBT, ListNBT): + raise TypeError( + f"Use the `{type(self).__name__}.to_object()` method to convert the tag to a python object." + ) + result = self.payload if not include_name else {self.name: self.payload} + if include_schema: + return result, type(self) + return result + + @override + def __repr__(self) -> str: + if self.name: + return f"{type(self).__name__}[{self.name!r}]({self.payload!r})" + return f"{type(self).__name__}({self.payload!r})" + + @override + def __eq__(self, other: object) -> bool: + """Check equality between two NBT tags.""" + if not isinstance(other, NBTag): + return NotImplemented + if type(self) is not type(other): + return False + return self.name == other.name and self.payload == other.payload + + @override + def to_nbt(self, name: str = "") -> NBTag: + """Convert the object to an NBT tag. + + .. warning:: This is already an NBT tag, so it will modify the name of the tag and return itself. + """ + self.name = name + return self + + @property + @abstractmethod + def value(self) -> PayloadType: + """Get the payload of the NBT tag in a python-friendly format.""" + ... + + +# endregion +# region NBT tags types + + +class EndNBT(NBTag): + """Sentinel tag used to mark the end of a TAG_Compound.""" + + __slots__ = () + + def __init__(self): + """Create a new EndNBT tag.""" + super().__init__(0, name="") + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = False) -> None: + self._write_header(buf, with_type=with_type, with_name=False) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> EndNBT: + _, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + return EndNBT() + + @override + def to_object( + self, include_schema: bool = False, include_name: bool = False + ) -> PayloadType | Mapping[str, PayloadType]: + return NotImplemented + + @property + @override + def value(self) -> PayloadType: + return NotImplemented + + +class ByteNBT(NBTag): + """NBT tag representing a single byte value, represented as a signed 8-bit integer.""" + + __slots__ = () + payload: int + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + self._write_header(buf, with_type=with_type, with_name=with_name) + if self.payload < -(1 << 7) or self.payload >= 1 << 7: + raise OverflowError("Byte value out of range.") + + buf.write_value(StructFormat.BYTE, self.payload) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteNBT: + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 1: + raise IOError("Buffer does not contain enough data to read a byte. (Empty buffer)") + + return ByteNBT(buf.read_value(StructFormat.BYTE), name=name) + + def __int__(self) -> int: + """Get the integer value of the ByteNBT tag.""" + return self.payload + + @property + @override + def value(self) -> int: + return self.payload + + +class ShortNBT(ByteNBT): + """NBT tag representing a short value, represented as a signed 16-bit integer.""" + + __slots__ = () + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + self._write_header(buf, with_type=with_type, with_name=with_name) + + if self.payload < -(1 << 15) or self.payload >= 1 << 15: + raise OverflowError("Short value out of range.") + + buf.write_value(StructFormat.SHORT, self.payload) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ShortNBT: + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 2: + raise IOError("Buffer does not contain enough data to read a short.") + + return ShortNBT(buf.read_value(StructFormat.SHORT), name=name) + + +class IntNBT(ByteNBT): + """NBT tag representing an integer value, represented as a signed 32-bit integer.""" + + __slots__ = () + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + self._write_header(buf, with_type=with_type, with_name=with_name) + + if self.payload < -(1 << 31) or self.payload >= 1 << 31: + raise OverflowError("Integer value out of range.") + + # No more messing around with the struct, we want 32 bits of data no matter what + buf.write_value(StructFormat.INT, self.payload) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntNBT: + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 4: + raise IOError("Buffer does not contain enough data to read an int.") + + return IntNBT(buf.read_value(StructFormat.INT), name=name) + + +class LongNBT(ByteNBT): + """NBT tag representing a long value, represented as a signed 64-bit integer.""" + + __slots__ = () + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + self._write_header(buf, with_type=with_type, with_name=with_name) + + if self.payload < -(1 << 63) or self.payload >= 1 << 63: + raise OverflowError("Long value out of range.") + + # No more messing around with the struct, we want 64 bits of data no matter what + buf.write_value(StructFormat.LONGLONG, self.payload) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongNBT: + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 8: + raise IOError("Buffer does not contain enough data to read a long.") + + return LongNBT(buf.read_value(StructFormat.LONGLONG), name=name) + + +class FloatNBT(NBTag): + """NBT tag representing a floating-point value, represented as a 32-bit IEEE 754-2008 binary32 value.""" + + payload: float + + __slots__ = () + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + self._write_header(buf, with_type=with_type, with_name=with_name) + buf.write_value(StructFormat.FLOAT, self.payload) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> FloatNBT: + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 4: + raise IOError("Buffer does not contain enough data to read a float.") + + return FloatNBT(buf.read_value(StructFormat.FLOAT), name=name) + + def __float__(self) -> float: + """Get the float value of the FloatNBT tag.""" + return self.payload + + @property + @override + def value(self) -> float: + return self.payload + + +class DoubleNBT(FloatNBT): + """NBT tag representing a double-precision floating-point value, represented as a 64-bit IEEE 754-2008 binary64.""" + + __slots__ = () + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + self._write_header(buf, with_type=with_type, with_name=with_name) + buf.write_value(StructFormat.DOUBLE, self.payload) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> DoubleNBT: + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + + if buf.remaining < 8: + raise IOError("Buffer does not contain enough data to read a double.") + + return DoubleNBT(buf.read_value(StructFormat.DOUBLE), name=name) + + +class ByteArrayNBT(NBTag): + """NBT tag representing an array of bytes. The length of the array is stored as a signed 32-bit integer.""" + + __slots__ = () + + payload: bytes + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + self._write_header(buf, with_type=with_type, with_name=with_name) + IntNBT(len(self.payload)).write_to(buf, with_type=False, with_name=False) + buf.write(self.payload) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ByteArrayNBT: + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + try: + length = IntNBT.read_from(buf, with_type=False, with_name=False).value + except IOError as exc: + raise IOError("Buffer does not contain enough data to read a byte array.") from exc + + if length < 0: + raise ValueError("Invalid byte array length.") + + if buf.remaining < length: + raise IOError( + f"Buffer does not contain enough data to read the byte array ({buf.remaining} < {length} bytes)." + ) + + return ByteArrayNBT(bytes(buf.read(length)), name=name) + + def __bytes__(self) -> bytes: + """Get the bytes value of the ByteArrayNBT tag.""" + return self.payload + + @override + def __repr__(self) -> str: + if self.name: + return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)})" + if len(self.payload) < 8: + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}(length={len(self.payload)}, {bytes(self.payload[:7])!r}...)" + + @property + @override + def value(self) -> bytes: + return self.payload + + +class StringNBT(NBTag): + """NBT tag representing an UTF-8 string value. The length of the string is stored as a signed 16-bit integer.""" + + __slots__ = () + + payload: str + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + self._write_header(buf, with_type=with_type, with_name=with_name) + if len(self.payload) > 32767: + # Check the length of the string (can't generate strings that long in tests) + raise ValueError("Maximum character limit for writing strings is 32767 characters.") # pragma: no cover + + data = bytes(self.payload, "utf-8") + ShortNBT(len(data)).write_to(buf, with_type=False, with_name=False) + buf.write(data) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> StringNBT: + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + try: + length = ShortNBT.read_from(buf, with_type=False, with_name=False).value + except IOError as exc: + raise IOError("Buffer does not contain enough data to read a string.") from exc + + if length < 0: + raise ValueError("Invalid string length.") + + if buf.remaining < length: + raise IOError("Buffer does not contain enough data to read the string.") + data = buf.read(length) + try: + return StringNBT(data.decode("utf-8"), name=name) + except UnicodeDecodeError: + raise # We want to know it + + @override + def __str__(self) -> str: + return self.payload + + @property + @override + def value(self) -> str: + return self.payload + + +class ListNBT(NBTag): + """NBT tag representing a list of tags. All tags in the list must be of the same type.""" + + __slots__ = () + + payload: list[NBTag] + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + self._write_header(buf, with_type=with_type, with_name=with_name) + + if not self.payload: + # Set the tag type to TAG_End if the list is empty + EndNBT().write_to(buf, with_name=False) + IntNBT(0).write_to(buf, with_name=False, with_type=False) + return + + if not all(isinstance(tag, NBTag) for tag in self.payload): # type: ignore # We want to check anyway + raise ValueError( + f"All items in a list must be NBTags. Got {self.payload!r}.\nUse NBTag.from_object() to convert " + "objects to tags first." + ) + + tag_type = _get_tag_type(self.payload[0]) + ByteNBT(tag_type).write_to(buf, with_name=False, with_type=False) + IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False) + for tag in self.payload: + if tag_type != _get_tag_type(tag): + raise ValueError(f"All tags in a list must be of the same type, got tag {tag!r}") + if tag.name != "": + raise ValueError(f"All tags in a list must be unnamed, got tag {tag!r}") + + tag.write_to(buf, with_type=False, with_name=False) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> ListNBT: + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + list_tag_type = ByteNBT.read_from(buf, with_type=False, with_name=False).payload + try: + length = IntNBT.read_from(buf, with_type=False, with_name=False).value + except IOError as exc: + raise IOError("Buffer does not contain enough data to read a list.") from exc + + if length < 1 or list_tag_type is NBTagType.END: + return ListNBT([], name=name) + + try: + list_tag_type = NBTagType(list_tag_type) + except ValueError as exc: + raise TypeError(f"Unknown tag type {list_tag_type}.") from exc + + list_type_class = ASSOCIATED_TYPES.get(list_tag_type, NBTag) + if list_type_class is NBTag: + raise TypeError(f"Unknown tag type {list_tag_type}.") # pragma: no cover + try: + payload = [list_type_class.read_from(buf, with_type=False, with_name=False) for _ in range(length)] + except IOError as exc: + raise IOError("Buffer does not contain enough data to read the list.") from exc + return ListNBT(payload, name=name) + + def __iter__(self) -> Iterator[NBTag]: + """Iterate over the tags in the list.""" + yield from self.payload + + @override + def __repr__(self) -> str: + if self.name: + return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" + if len(self.payload) < 8: + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" + + @override + def to_object( + self, include_schema: bool = False, include_name: bool = False + ) -> ( + list[PayloadType] + | Mapping[str, list[PayloadType]] + | tuple[list[PayloadType] | Mapping[str, list[PayloadType]], list[FromObjectSchema]] + ): + result = [tag.to_object() for tag in self.payload] + result = cast(List[PayloadType], result) + result = result if not include_name else {self.name: result} + if include_schema: + subschemas = [ + cast( + Tuple[PayloadType, FromObjectSchema], + tag.to_object(include_schema=True), + )[1] + for tag in self.payload + ] + if len(result) == 0: + return result, [] + + first = subschemas[0] + if all(schema == first for schema in subschemas): + return result, [first] + + if not isinstance(first, (dict, list)): + raise TypeError(f"The schema must contain either a dict or a list. Found {first!r}") + # This will take care of ensuring either everything is a dict or a list + if not all(isinstance(schema, type(first)) for schema in subschemas): + raise TypeError(f"All items in the list must have the same type. Found {subschemas!r}") + return result, subschemas + return result + + @property + @override + def value(self) -> list[PayloadType]: + return [tag.value for tag in self.payload] + + +class CompoundNBT(NBTag): + """NBT tag representing a compound of named tags.""" + + __slots__ = () + + payload: list[NBTag] + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + self._write_header(buf, with_type=with_type, with_name=with_name) + if not self.payload: + EndNBT().write_to(buf, with_name=False, with_type=True) + return + if not all(isinstance(tag, NBTag) for tag in self.payload): # type: ignore # We want to check anyway + raise ValueError( + f"All items in a compound must be NBTags. Got {self.payload!r}.\n" + "Use NBTag.from_object() to convert objects to tags first." + ) + + if not all(tag.name for tag in self.payload): + raise ValueError(f"All tags in a compound must be named, got tags {self.payload!r}") + + if len(self.payload) != len({tag.name for tag in self.payload}): # Check for duplicate names + raise ValueError("All tags in a compound must have unique names.") + + for tag in self.payload: + tag.write_to(buf) + EndNBT().write_to(buf, with_name=False, with_type=True) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> CompoundNBT: + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if _get_tag_type(cls) != tag_type: + raise TypeError(f"Expected a {_get_tag_type(cls).name} tag, but found a different tag ({tag_type.name}).") + + payload: list[NBTag] = [] + while True: + child_name, child_type = cls._read_header(buf, with_name=True, read_type=True) + if child_type is NBTagType.END: + break + # The name and type of the tag have already been read + tag = ASSOCIATED_TYPES[child_type].read_from(buf, with_type=False, with_name=False) + tag.name = child_name + payload.append(tag) + return CompoundNBT(payload, name=name) + + def __iter__(self): + """Iterate over the tags in the compound.""" + for tag in self.payload: + yield tag.name, tag + + @override + def __repr__(self) -> str: + if self.name: + return f"{type(self).__name__}[{self.name!r}]({dict(self)})" + return f"{type(self).__name__}({dict(self)})" + + @override + def to_object( + self, include_schema: bool = False, include_name: bool = False + ) -> ( + Mapping[str, PayloadType] + | Mapping[str, Mapping[str, PayloadType]] + | tuple[ + Mapping[str, PayloadType] | Mapping[str, Mapping[str, PayloadType]], + Mapping[str, FromObjectSchema], + ] + ): + result = {tag.name: tag.to_object() for tag in self.payload} + result = cast(Mapping[str, PayloadType], result) + result = result if not include_name else {self.name: result} + if include_schema: + subschemas = { + tag.name: cast( + Tuple[PayloadType, FromObjectSchema], + tag.to_object(include_schema=True), + )[1] + for tag in self.payload + } + return result, subschemas + return result + + @override + def __eq__(self, other: object) -> bool: + """Check equality between two CompoundNBT tags. + + :param other: The other CompoundNBT tag to compare to. + + :return: True if the tags are equal, False otherwise. + + .. note:: The order of the tags is not guaranteed, but the names of the tags must match. This function assumes + that there are no duplicate tags in the compound. + """ + # The order of the tags is not guaranteed + if not isinstance(other, NBTag): + return NotImplemented + if type(self) is not type(other): + return False + if self.name != other.name: + return False + other = cast(CompoundNBT, other) + if len(self.payload) != len(other.payload): + return False + return all(tag in other.payload for tag in self.payload) + + @property + @override + def value(self) -> dict[str, PayloadType]: + return {tag.name: tag.value for tag in self.payload} + + +class IntArrayNBT(NBTag): + """NBT tag representing an array of integers. The length of the array is stored as a signed 32-bit integer.""" + + __slots__ = () + + payload: list[int] + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + self._write_header(buf, with_type=with_type, with_name=with_name) + + if any(not isinstance(item, int) for item in self.payload): # type: ignore # We want to check anyway + raise ValueError("All items in an integer array must be integers.") + + if any(item < -(1 << 31) or item >= 1 << 31 for item in self.payload): + raise OverflowError("Integer array contains values out of range.") + + IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False) + for i in self.payload: + IntNBT(i).write_to(buf, with_name=False, with_type=False) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> IntArrayNBT: + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != NBTagType.INT_ARRAY: + raise TypeError(f"Expected an INT_ARRAY tag, but found a different tag ({tag_type}).") + length = IntNBT.read_from(buf, with_type=False, with_name=False).value + try: + payload = [IntNBT.read_from(buf, with_type is NBTagType.INT, with_name=False).value for _ in range(length)] + except IOError as exc: + raise IOError( + "Buffer does not contain enough data to read the entire integer array. (Incomplete data)" + ) from exc + return IntArrayNBT(payload, name=name) + + @override + def __repr__(self) -> str: + if self.name: + return f"{type(self).__name__}[{self.name!r}](length={len(self.payload)}, {self.payload!r})" + if len(self.payload) < 8: + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload!r})" + return f"{type(self).__name__}(length={len(self.payload)}, {self.payload[:7]!r}...)" + + def __iter__(self) -> Iterator[int]: + """Iterate over the integers in the array.""" + yield from self.payload + + @property + @override + def value(self) -> list[int]: + return self.payload + + +class LongArrayNBT(IntArrayNBT): + """NBT tag representing an array of longs. The length of the array is stored as a signed 32-bit integer.""" + + __slots__ = () + + @override + def write_to(self, buf: Buffer, with_type: bool = True, with_name: bool = True) -> None: + self._write_header(buf, with_type=with_type, with_name=with_name) + + if any(not isinstance(item, int) for item in self.payload): # type: ignore # We want to check anyway + raise ValueError(f"All items in a long array must be integers. ({self.payload})") + + if any(item < -(1 << 63) or item >= 1 << 63 for item in self.payload): + raise OverflowError(f"Long array contains values out of range. ({self.payload})") + + IntNBT(len(self.payload)).write_to(buf, with_name=False, with_type=False) + for i in self.payload: + LongNBT(i).write_to(buf, with_name=False, with_type=False) + + @override + @classmethod + def read_from(cls, buf: Buffer, with_type: bool = True, with_name: bool = True) -> LongArrayNBT: + name, tag_type = cls._read_header(buf, read_type=with_type, with_name=with_name) + if tag_type != NBTagType.LONG_ARRAY: + raise TypeError(f"Expected a LONG_ARRAY tag, but found a different tag ({tag_type}).") + length = IntNBT.read_from(buf, with_type=False, with_name=False).payload + + try: + payload = [LongNBT.read_from(buf, with_type=False, with_name=False).payload for _ in range(length)] + except IOError as exc: + raise IOError( + "Buffer does not contain enough data to read the entire long array. (Incomplete data)" + ) from exc + return LongArrayNBT(payload, name=name) + + +# endregion + +# region: NBT Associated Types +ASSOCIATED_TYPES: dict[NBTagType, type[NBTag]] = { + NBTagType.END: EndNBT, + NBTagType.BYTE: ByteNBT, + NBTagType.SHORT: ShortNBT, + NBTagType.INT: IntNBT, + NBTagType.LONG: LongNBT, + NBTagType.FLOAT: FloatNBT, + NBTagType.DOUBLE: DoubleNBT, + NBTagType.BYTE_ARRAY: ByteArrayNBT, + NBTagType.STRING: StringNBT, + NBTagType.LIST: ListNBT, + NBTagType.COMPOUND: CompoundNBT, + NBTagType.INT_ARRAY: IntArrayNBT, + NBTagType.LONG_ARRAY: LongArrayNBT, +} + + +def _get_tag_type(tag: NBTag | type[NBTag]) -> NBTagType: + """Get the tag type of an NBTag object or class. + + :param tag: The tag to get the type of. + + :return: The tag type of the tag. + """ + cls = tag if isinstance(tag, type) else type(tag) + + if cls is NBTag: + return NBTagType.COMPOUND + for tag_type, tag_cls in ASSOCIATED_TYPES.items(): + if cls is tag_cls: + return tag_type + + raise ValueError(f"Unknown tag type {cls}.") # pragma: no cover + + +# endregion diff --git a/tests/mcproto/types/test_nbt.py b/tests/mcproto/types/test_nbt.py new file mode 100644 index 00000000..4ac0191c --- /dev/null +++ b/tests/mcproto/types/test_nbt.py @@ -0,0 +1,1396 @@ +from __future__ import annotations + +import struct +from typing import Any, Dict, List, cast + +import pytest + +from mcproto.buffer import Buffer +from mcproto.types.nbt import ( + ByteArrayNBT, + ByteNBT, + CompoundNBT, + DoubleNBT, + EndNBT, + FloatNBT, + IntArrayNBT, + IntNBT, + ListNBT, + LongArrayNBT, + LongNBT, + NBTag, + NBTagType, + PayloadType, + ShortNBT, + StringNBT, +) + +# region EndNBT + + +def test_serialize_deserialize_end(): + """Test serialization/deserialization of NBT END tag.""" + output_bytes = EndNBT().serialize() + assert output_bytes == bytearray.fromhex("00") + + buffer = Buffer() + EndNBT().write_to(buffer) + assert buffer == bytearray.fromhex("00") + + buffer.clear() + EndNBT().write_to(buffer, with_name=False) + assert buffer == bytearray.fromhex("00") + + buffer = Buffer(bytearray.fromhex("00")) + assert EndNBT.deserialize(buffer) == EndNBT() + + +# endregion +# region Numerical NBT tests + + +@pytest.mark.parametrize( + ("nbt_class", "value", "expected_bytes"), + [ + (ByteNBT, 0, bytearray.fromhex("01 00")), + (ByteNBT, 1, bytearray.fromhex("01 01")), + (ByteNBT, 127, bytearray.fromhex("01 7F")), + (ByteNBT, -128, bytearray.fromhex("01 80")), + (ByteNBT, -1, bytearray.fromhex("01 FF")), + (ByteNBT, 12, bytearray.fromhex("01 0C")), + (ShortNBT, 0, bytearray.fromhex("02 00 00")), + (ShortNBT, 1, bytearray.fromhex("02 00 01")), + (ShortNBT, 32767, bytearray.fromhex("02 7F FF")), + (ShortNBT, -32768, bytearray.fromhex("02 80 00")), + (ShortNBT, -1, bytearray.fromhex("02 FF FF")), + (ShortNBT, 12, bytearray.fromhex("02 00 0C")), + (IntNBT, 0, bytearray.fromhex("03 00 00 00 00")), + (IntNBT, 1, bytearray.fromhex("03 00 00 00 01")), + (IntNBT, 2147483647, bytearray.fromhex("03 7F FF FF FF")), + (IntNBT, -2147483648, bytearray.fromhex("03 80 00 00 00")), + (IntNBT, -1, bytearray.fromhex("03 FF FF FF FF")), + (IntNBT, 12, bytearray.fromhex("03 00 00 00 0C")), + (LongNBT, 0, bytearray.fromhex("04 00 00 00 00 00 00 00 00")), + (LongNBT, 1, bytearray.fromhex("04 00 00 00 00 00 00 00 01")), + (LongNBT, (1 << 63) - 1, bytearray.fromhex("04 7F FF FF FF FF FF FF FF")), + (LongNBT, -(1 << 63), bytearray.fromhex("04 80 00 00 00 00 00 00 00")), + (LongNBT, -1, bytearray.fromhex("04 FF FF FF FF FF FF FF FF")), + (LongNBT, 12, bytearray.fromhex("04 00 00 00 00 00 00 00 0C")), + (FloatNBT, 1.0, bytearray.fromhex("05") + bytes(struct.pack(">f", 1.0))), + (FloatNBT, 0.25, bytearray.fromhex("05") + bytes(struct.pack(">f", 0.25))), + (FloatNBT, -1.0, bytearray.fromhex("05") + bytes(struct.pack(">f", -1.0))), + (FloatNBT, 12.0, bytearray.fromhex("05") + bytes(struct.pack(">f", 12.0))), + (DoubleNBT, 1.0, bytearray.fromhex("06") + bytes(struct.pack(">d", 1.0))), + (DoubleNBT, 0.25, bytearray.fromhex("06") + bytes(struct.pack(">d", 0.25))), + (DoubleNBT, -1.0, bytearray.fromhex("06") + bytes(struct.pack(">d", -1.0))), + (DoubleNBT, 12.0, bytearray.fromhex("06") + bytes(struct.pack(">d", 12.0))), + (ByteArrayNBT, b"", bytearray.fromhex("07 00 00 00 00")), + (ByteArrayNBT, b"\x00", bytearray.fromhex("07 00 00 00 01") + b"\x00"), + (ByteArrayNBT, b"\x00\x01", bytearray.fromhex("07 00 00 00 02") + b"\x00\x01"), + ( + ByteArrayNBT, + b"\x00\x01\x02", + bytearray.fromhex("07 00 00 00 03") + b"\x00\x01\x02", + ), + ( + ByteArrayNBT, + b"\x00\x01\x02\x03", + bytearray.fromhex("07 00 00 00 04") + b"\x00\x01\x02\x03", + ), + ( + ByteArrayNBT, + b"\xff" * 1024, + bytearray.fromhex("07 00 00 04 00") + b"\xff" * 1024, + ), + ( + ByteArrayNBT, + bytes((n - 1) * n * 2 % 256 for n in range(256)), + bytearray.fromhex("07 00 00 01 00") + bytes((n - 1) * n * 2 % 256 for n in range(256)), + ), + (StringNBT, "", bytearray.fromhex("08 00 00")), + (StringNBT, "test", bytearray.fromhex("08 00 04") + b"test"), + (StringNBT, "a" * 100, bytearray.fromhex("08 00 64") + b"a" * (100)), + (StringNBT, "&à@é", bytearray.fromhex("08 00 06") + bytes("&à@é", "utf-8")), + (ListNBT, [], bytearray.fromhex("09 00 00 00 00 00")), + (ListNBT, [ByteNBT(0)], bytearray.fromhex("09 01 00 00 00 01 00")), + ( + ListNBT, + [ShortNBT(127), ShortNBT(256)], + bytearray.fromhex("09 02 00 00 00 02 00 7F 01 00"), + ), + ( + ListNBT, + [ListNBT([ByteNBT(0)]), ListNBT([IntNBT(256)])], + bytearray.fromhex("09 09 00 00 00 02 01 00 00 00 01 00 03 00 00 00 01 00 00 01 00"), + ), + (CompoundNBT, [], bytearray.fromhex("0A 00")), + ( + CompoundNBT, + [ByteNBT(0, name="test")], + bytearray.fromhex("0A") + ByteNBT(0, name="test").serialize() + b"\x00", + ), + ( + CompoundNBT, + [ShortNBT(128, "Short"), ByteNBT(-1, "Byte")], + bytearray.fromhex("0A") + ShortNBT(128, "Short").serialize() + ByteNBT(-1, "Byte").serialize() + b"\x00", + ), + ( + CompoundNBT, + [CompoundNBT([ByteNBT(0, name="Byte")], name="test")], + bytearray.fromhex("0A") + CompoundNBT([ByteNBT(0, name="Byte")], name="test").serialize() + b"\x00", + ), + ( + CompoundNBT, + [ + CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test"), + IntNBT(-1, "Int 2"), + ], + bytearray.fromhex("0A") + + CompoundNBT([ByteNBT(0, name="Byte"), IntNBT(0, name="Int")], "test").serialize() + + IntNBT(-1, "Int 2").serialize() + + b"\x00", + ), + (IntArrayNBT, [], bytearray.fromhex("0B 00 00 00 00")), + (IntArrayNBT, [0], bytearray.fromhex("0B 00 00 00 01 00 00 00 00")), + ( + IntArrayNBT, + [0, 1], + bytearray.fromhex("0B 00 00 00 02 00 00 00 00 00 00 00 01"), + ), + ( + IntArrayNBT, + [1, 2, 3], + bytearray.fromhex("0B 00 00 00 03 00 00 00 01 00 00 00 02 00 00 00 03"), + ), + (IntArrayNBT, [(1 << 31) - 1], bytearray.fromhex("0B 00 00 00 01 7F FF FF FF")), + ( + IntArrayNBT, + [(1 << 31) - 1, (1 << 31) - 2], + bytearray.fromhex("0B 00 00 00 02 7F FF FF FF 7F FF FF FE"), + ), + ( + IntArrayNBT, + [-1, -2, -3], + bytearray.fromhex("0B 00 00 00 03 FF FF FF FF FF FF FF FE FF FF FF FD"), + ), + ( + IntArrayNBT, + [12] * 1024, + bytearray.fromhex("0B 00 00 04 00") + b"\x00\x00\x00\x0c" * 1024, + ), + (LongArrayNBT, [], bytearray.fromhex("0C 00 00 00 00")), + ( + LongArrayNBT, + [0], + bytearray.fromhex("0C 00 00 00 01 00 00 00 00 00 00 00 00"), + ), + ( + LongArrayNBT, + [0, 1], + bytearray.fromhex("0C 00 00 00 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 01"), + ), + ( + LongArrayNBT, + [1, 2, 3], + bytearray.fromhex( + "0C 00 00 00 03 00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 03" + ), + ), + ( + LongArrayNBT, + [(1 << 63) - 1], + bytearray.fromhex("0C 00 00 00 01 7F FF FF FF FF FF FF FF"), + ), + ( + LongArrayNBT, + [(1 << 63) - 1, (1 << 63) - 2], + bytearray.fromhex("0C 00 00 00 02 7F FF FF FF FF FF FF FF 7F FF FF FF FF FF FF FE"), + ), + ( + LongArrayNBT, + [-1, -2, -3], + bytearray.fromhex( + "0C 00 00 00 03 FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF FE FF FF FF FF FF FF FF FD" + ), + ), + ( + LongArrayNBT, + [12] * 1024, + bytearray.fromhex("0C 00 00 04 00") + b"\x00\x00\x00\x00\x00\x00\x00\x0c" * 1024, + ), + ], +) +def test_serialize_deserialize_noname(nbt_class: type[NBTag], value: PayloadType, expected_bytes: bytes): + """Test serialization/deserialization of NBT tag without name.""" + # Test serialization + output_bytes = nbt_class(value).serialize(with_name=False) + output_bytes_no_type = nbt_class(value).serialize(with_type=False, with_name=False) + assert output_bytes == expected_bytes + assert output_bytes_no_type == expected_bytes[1:] + + buffer = Buffer() + nbt_class(value).write_to(buffer, with_name=False) + assert buffer == expected_bytes + + # Test deserialization + buffer = Buffer(expected_bytes) + assert NBTag.deserialize(buffer, with_name=False) == nbt_class(value) + + buffer = Buffer(expected_bytes[1:]) + assert nbt_class.deserialize(buffer, with_type=False, with_name=False) == nbt_class(value) + + buffer = Buffer(expected_bytes) + assert nbt_class.read_from(buffer, with_name=False) == nbt_class(value) + + buffer = Buffer(expected_bytes[1:]) + assert nbt_class.read_from(buffer, with_type=False, with_name=False) == nbt_class(value) + + +@pytest.mark.parametrize( + ("nbt_class", "value", "name", "expected_bytes"), + [ + ( + ByteNBT, + 0, + "test", + bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("00"), + ), + ( + ByteNBT, + 1, + "a", + bytearray.fromhex("01") + b"\x00\x01a" + bytearray.fromhex("01"), + ), + ( + ByteNBT, + 127, + "&à@é", + bytearray.fromhex("01 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F"), + ), + ( + ByteNBT, + -128, + "test", + bytearray.fromhex("01") + b"\x00\x04test" + bytearray.fromhex("80"), + ), + ( + ByteNBT, + 12, + "a" * 100, + bytearray.fromhex("01") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("0C"), + ), + ( + ShortNBT, + 0, + "test", + bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("00 00"), + ), + ( + ShortNBT, + 1, + "a", + bytearray.fromhex("02") + b"\x00\x01a" + bytearray.fromhex("00 01"), + ), + ( + ShortNBT, + 32767, + "&à@é", + bytearray.fromhex("02 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF"), + ), + ( + ShortNBT, + -32768, + "test", + bytearray.fromhex("02") + b"\x00\x04test" + bytearray.fromhex("80 00"), + ), + ( + ShortNBT, + 12, + "a" * 100, + bytearray.fromhex("02") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 0C"), + ), + ( + IntNBT, + 0, + "test", + bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), + ), + ( + IntNBT, + 1, + "a", + bytearray.fromhex("03") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01"), + ), + ( + IntNBT, + 2147483647, + "&à@é", + bytearray.fromhex("03 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF FF FF"), + ), + ( + IntNBT, + -2147483648, + "test", + bytearray.fromhex("03") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00"), + ), + ( + IntNBT, + 12, + "a" * 100, + bytearray.fromhex("03") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 0C"), + ), + ( + LongNBT, + 0, + "test", + bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00 00 00 00"), + ), + ( + LongNBT, + 1, + "a", + bytearray.fromhex("04") + b"\x00\x01a" + bytearray.fromhex("00 00 00 00 00 00 00 01"), + ), + ( + LongNBT, + (1 << 63) - 1, + "&à@é", + bytearray.fromhex("04 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("7F FF FF FF FF FF FF FF"), + ), + ( + LongNBT, + -1 << 63, + "test", + bytearray.fromhex("04") + b"\x00\x04test" + bytearray.fromhex("80 00 00 00 00 00 00 00"), + ), + ( + LongNBT, + 12, + "a" * 100, + bytearray.fromhex("04") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 00 00 00 00 00 0C"), + ), + ( + FloatNBT, + 1.0, + "test", + bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 1.0)), + ), + ( + FloatNBT, + 0.25, + "a", + bytearray.fromhex("05") + b"\x00\x01a" + bytes(struct.pack(">f", 0.25)), + ), + ( + FloatNBT, + -1.0, + "&à@é", + bytearray.fromhex("05 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">f", -1.0)), + ), + ( + FloatNBT, + 12.0, + "test", + bytearray.fromhex("05") + b"\x00\x04test" + bytes(struct.pack(">f", 12.0)), + ), + ( + DoubleNBT, + 1.0, + "test", + bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 1.0)), + ), + ( + DoubleNBT, + 0.25, + "a", + bytearray.fromhex("06") + b"\x00\x01a" + bytes(struct.pack(">d", 0.25)), + ), + ( + DoubleNBT, + -1.0, + "&à@é", + bytearray.fromhex("06 00 06") + bytes("&à@é", "utf-8") + bytes(struct.pack(">d", -1.0)), + ), + ( + DoubleNBT, + 12.0, + "test", + bytearray.fromhex("06") + b"\x00\x04test" + bytes(struct.pack(">d", 12.0)), + ), + ( + ByteArrayNBT, + b"", + "test", + bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), + ), + ( + ByteArrayNBT, + b"\x00", + "a", + bytearray.fromhex("07") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01") + b"\x00", + ), + ( + ByteArrayNBT, + b"\x00\x01", + "&à@é", + bytearray.fromhex("07 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("00 00 00 02") + b"\x00\x01", + ), + ( + ByteArrayNBT, + b"\x00\x01\x02", + "test", + bytearray.fromhex("07") + b"\x00\x04test" + bytearray.fromhex("00 00 00 03") + b"\x00\x01\x02", + ), + ( + ByteArrayNBT, + b"\xff" * 1024, + "a" * 100, + bytearray.fromhex("07") + b"\x00\x64" + b"a" * 100 + bytearray.fromhex("00 00 04 00") + b"\xff" * 1024, + ), + ( + StringNBT, + "", + "test", + bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 00"), + ), + ( + StringNBT, + "test", + "a", + bytearray.fromhex("08") + b"\x00\x01a" + bytearray.fromhex("00 04") + b"test", + ), + ( + StringNBT, + "a" * 100, + "&à@é", + bytearray.fromhex("08 00 06") + bytes("&à@é", "utf-8") + bytearray.fromhex("00 64") + b"a" * 100, + ), + ( + StringNBT, + "&à@é", + "test", + bytearray.fromhex("08") + b"\x00\x04test" + bytearray.fromhex("00 06") + bytes("&à@é", "utf-8"), + ), + ( + ListNBT, + [], + "test", + bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00 00"), + ), + ( + ListNBT, + [ByteNBT(-1)], + "a", + bytearray.fromhex("09") + b"\x00\x01a" + bytearray.fromhex("01 00 00 00 01 FF"), + ), + ( + ListNBT, + [ShortNBT(127), ShortNBT(256)], + "test", + bytearray.fromhex("09") + b"\x00\x04test" + bytearray.fromhex("02 00 00 00 02 00 7F 01 00"), + ), + ( + ListNBT, + [ListNBT([ByteNBT(-1)]), ListNBT([IntNBT(256)])], + "a", + bytearray.fromhex("09") + + b"\x00\x01a" + + bytearray.fromhex("09 00 00 00 02 01 00 00 00 01 FF 03 00 00 00 01 00 00 01 00"), + ), + ( + CompoundNBT, + [], + "test", + bytearray.fromhex("0A") + b"\x00\x04test" + bytearray.fromhex("00"), + ), + ( + CompoundNBT, + [ByteNBT(0, name="Byte")], + "test", + bytearray.fromhex("0A") + b"\x00\x04test" + ByteNBT(0, name="Byte").serialize() + b"\x00", + ), + ( + CompoundNBT, + [ShortNBT(128, "Short"), ByteNBT(-1, "Byte")], + "test", + bytearray.fromhex("0A") + + b"\x00\x04test" + + ShortNBT(128, "Short").serialize() + + ByteNBT(-1, "Byte").serialize() + + b"\x00", + ), + ( + CompoundNBT, + [CompoundNBT([ByteNBT(0, name="Byte")], name="test")], + "test", + bytearray.fromhex("0A") + + b"\x00\x04test" + + CompoundNBT([ByteNBT(0, name="Byte")], "test").serialize() + + b"\x00", + ), + ( + CompoundNBT, + [ListNBT([ByteNBT(0)], name="List")], + "test", + bytearray.fromhex("0A") + b"\x00\x04test" + ListNBT([ByteNBT(0)], name="List").serialize() + b"\x00", + ), + ( + IntArrayNBT, + [], + "test", + bytearray.fromhex("0B") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), + ), + ( + IntArrayNBT, + [0], + "a", + bytearray.fromhex("0B") + b"\x00\x01a" + bytearray.fromhex("00 00 00 01") + b"\x00\x00\x00\x00", + ), + ( + IntArrayNBT, + [0, 1], + "&à@é", + bytearray.fromhex("0B 00 06") + + bytes("&à@é", "utf-8") + + bytearray.fromhex("00 00 00 02") + + b"\x00\x00\x00\x00\x00\x00\x00\x01", + ), + ( + IntArrayNBT, + [1, 2, 3], + "test", + bytearray.fromhex("0B") + + b"\x00\x04test" + + bytearray.fromhex("00 00 00 03") + + b"\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03", + ), + ( + IntArrayNBT, + [(1 << 31) - 1], + "a" * 100, + bytearray.fromhex("0B") + + b"\x00\x64" + + b"a" * 100 + + bytearray.fromhex("00 00 00 01") + + b"\x7f\xff\xff\xff", + ), + ( + LongArrayNBT, + [], + "test", + bytearray.fromhex("0C") + b"\x00\x04test" + bytearray.fromhex("00 00 00 00"), + ), + ( + LongArrayNBT, + [0], + "a", + bytearray.fromhex("0C") + + b"\x00\x01a" + + bytearray.fromhex("00 00 00 01") + + b"\x00\x00\x00\x00\x00\x00\x00\x00", + ), + ( + LongArrayNBT, + [0, 1], + "&à@é", + bytearray.fromhex("0C 00 06") + + bytes("&à@é", "utf-8") + + bytearray.fromhex("00 00 00 02") + + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + ), + ( + LongArrayNBT, + [1, 2, 3], + "test", + bytearray.fromhex("0C") + + b"\x00\x04test" + + bytearray.fromhex("00 00 00 03") + + bytearray.fromhex("00 00 00 00 00 00 00 01 00 00 00 00 00 00 00 02 00 00 00 00 00 00 00 03"), + ), + ( + LongArrayNBT, + [(1 << 63) - 1] * 100, + "a" * 100, + bytearray.fromhex("0C") + + b"\x00\x64" + + b"a" * 100 + + bytearray.fromhex("00 00 00 64") + + b"\x7f\xff\xff\xff\xff\xff\xff\xff" * 100, + ), + ], +) +def test_serialize_deserialize(nbt_class: type[NBTag], value: PayloadType, name: str, expected_bytes: bytes): + """Test serialization/deserialization of NBT tag with name.""" + # Test serialization + output_bytes = nbt_class(value, name).serialize() + output_bytes_no_type = nbt_class(value, name).serialize(with_type=False) + assert output_bytes == expected_bytes + assert output_bytes_no_type == expected_bytes[1:] + + buffer = Buffer() + nbt_class(value, name).write_to(buffer) + assert buffer == expected_bytes + + # Test deserialization + buffer = Buffer(expected_bytes * 2) + assert buffer.remaining == len(expected_bytes) * 2 + assert NBTag.deserialize(buffer) == nbt_class(value, name=name) + assert buffer.remaining == len(expected_bytes) + assert NBTag.deserialize(buffer) == nbt_class(value, name=name) + assert buffer.remaining == 0 + + buffer = Buffer(expected_bytes[1:]) + assert nbt_class.deserialize(buffer, with_type=False) == nbt_class(value, name=name) + + buffer = Buffer(expected_bytes) + assert nbt_class.read_from(buffer) == nbt_class(value, name=name) + + buffer = Buffer(expected_bytes[1:]) + assert nbt_class.read_from(buffer, with_type=False) == nbt_class(value, name=name) + + +@pytest.mark.parametrize( + ("nbt_class", "size", "tag"), + [ + (ByteNBT, 8, NBTagType.BYTE), + (ShortNBT, 16, NBTagType.SHORT), + (IntNBT, 32, NBTagType.INT), + (LongNBT, 64, NBTagType.LONG), + ], +) +def test_serialize_deserialize_numerical_fail(nbt_class: type[NBTag], size: int, tag: NBTagType): + """Test serialization/deserialization of NBT NUM tag with invalid value.""" + # Out of bounds + with pytest.raises(OverflowError): + nbt_class(1 << (size - 1)).serialize(with_name=False) + + with pytest.raises(OverflowError): + nbt_class(-(1 << (size - 1)) - 1).serialize(with_name=False) + + # Deserialization + buffer = Buffer(bytearray([tag.value + 1] + [0] * (size // 8))) + with pytest.raises(TypeError): # Tries to read a nbt_class, but it's one higher + nbt_class.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([tag.value] + [0] * ((size // 8) - 1))) + with pytest.raises(IOError): + nbt_class.read_from(buffer, with_name=False) + + buffer = Buffer(bytearray([tag.value, 0, 0] + [0] * (size // 8))) + assert nbt_class.read_from(buffer, with_name=True) == nbt_class(0) + + +# endregion + +# region FloatNBT + + +def test_serialize_deserialize_float_fail(): + """Test serialization/deserialization of NBT FLOAT tag with invalid value.""" + with pytest.raises(struct.error): + FloatNBT("test").serialize(with_name=False) + + with pytest.raises(OverflowError): + FloatNBT(1e39, "test").serialize() + + with pytest.raises(OverflowError): + FloatNBT(-1e39, "test").serialize() + + # Deserialization + buffer = Buffer(bytearray([NBTagType.BYTE] + [0] * 4)) + with pytest.raises(TypeError): # Tries to read a FloatNBT, but it's a ByteNBT + FloatNBT.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.FLOAT, 0, 0, 0])) + with pytest.raises(IOError): + FloatNBT.read_from(buffer, with_name=False) + + +# endregion +# region DoubleNBT + + +def test_serialize_deserialize_double_fail(): + """Test serialization/deserialization of NBT DOUBLE tag with invalid value.""" + with pytest.raises(struct.error): + DoubleNBT("test").serialize(with_name=False) + + # Deserialization + buffer = Buffer(bytearray([0x01] + [0] * 8)) + with pytest.raises(TypeError): # Tries to read a DoubleNBT, but it's a ByteNBT + DoubleNBT.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.DOUBLE, 0, 0, 0, 0, 0, 0, 0])) + with pytest.raises(IOError): + DoubleNBT.read_from(buffer, with_name=False) + + +# endregion +# region ByteArrayNBT + + +def test_serialize_deserialize_bytearray_fail(): + """Test serialization/deserialization of NBT BYTEARRAY tag with invalid value.""" + # Deserialization + buffer = Buffer(bytearray([0x01] + [0] * 4)) + with pytest.raises(TypeError): # Tries to read a ByteArrayNBT, but it's a ByteNBT + ByteArrayNBT.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0])) # Missing length bytes + with pytest.raises(IOError): + ByteArrayNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0, 1])) # Missing data bytes + with pytest.raises(IOError): + ByteArrayNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0, 0, 0, 2, 0])) # Missing data bytes + with pytest.raises(IOError): + ByteArrayNBT.read_from(buffer, with_name=False) + + # Negative length + buffer = Buffer(bytearray([NBTagType.BYTE_ARRAY, 0xFF, 0xFF, 0xFF, 0xFF])) # length = -1 + with pytest.raises(ValueError): + ByteArrayNBT.deserialize(buffer, with_name=False) + + +# endregion +# region StringNBT + + +def test_serialize_deserialize_string_fail(): + """Test serialization/deserialization of NBT STRING tag with invalid value.""" + # Deserialization + buffer = Buffer(bytearray([0x01, 0, 0])) + with pytest.raises(TypeError): # Tries to read a StringNBT, but it's a ByteNBT + StringNBT.deserialize(buffer, with_name=False) + + # Not enough data for the length + buffer = Buffer(bytearray([NBTagType.STRING, 0])) + with pytest.raises(IOError): + StringNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.STRING, 0, 1])) + with pytest.raises(IOError): + StringNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.STRING, 0, 2, 0])) + with pytest.raises(IOError): + StringNBT.read_from(buffer, with_name=False) + + # Negative length + buffer = Buffer(bytearray([NBTagType.STRING, 0xFF, 0xFF])) # length = -1 + with pytest.raises(ValueError): + StringNBT.deserialize(buffer, with_name=False) + + # Invalid UTF-8 + buffer = Buffer(bytearray([NBTagType.STRING, 0, 1, 0xC0, 0x80])) + with pytest.raises(UnicodeDecodeError): + StringNBT.read_from(buffer, with_name=False) + + +# endregion +# region ListNBT + + +@pytest.mark.parametrize( + ("payload", "error"), + [ + ([ByteNBT(0), IntNBT(0)], ValueError), + ([ByteNBT(0), "test"], ValueError), + ([ByteNBT(0), None], ValueError), + ([ByteNBT(0), ByteNBT(-1, "Hello World")], ValueError), # All unnamed tags + ([ByteNBT(128), ByteNBT(-1)], OverflowError), # Check for error propagation + ], +) +def test_serialize_list_fail(payload: PayloadType, error: type[Exception]): + """Test serialization of NBT LIST tag with invalid value.""" + with pytest.raises(error): + ListNBT(payload, "test").serialize() + + +def test_deserialize_list_fail(): + """Test deserialization of NBT LIST tag with invalid value.""" + # Wrong tag type + buffer = Buffer(bytearray([0x09, 255, 0, 0, 0, 1, 0])) + with pytest.raises(TypeError): + ListNBT.deserialize(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([0x09, 1, 0, 0, 0, 1])) + with pytest.raises(IOError): + ListNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([0x09, 1, 0, 0, 0])) + with pytest.raises(IOError): + ListNBT.read_from(buffer, with_name=False) + + +# endregion +# region CompoundNBT + + +@pytest.mark.parametrize( + ("payload", "error"), + [ + ([ByteNBT(0, name="Hello"), IntNBT(0)], ValueError), + ([ByteNBT(0, name="hi"), "test"], ValueError), + ([ByteNBT(0, name="hi"), None], ValueError), + ([ByteNBT(0), ByteNBT(-1, "Hello World")], ValueError), # All unnamed tags + ( + [ByteNBT(128, name="Jello"), ByteNBT(-1, name="Bonjour")], + OverflowError, + ), # Check for error propagation + ], +) +def test_serialize_compound_fail(payload: PayloadType, error: type[Exception]): + """Test serialization of NBT COMPOUND tag with invalid value.""" + with pytest.raises(error): + CompoundNBT(payload, "test").serialize() + + # Double name + with pytest.raises(ValueError): + CompoundNBT([ByteNBT(0, name="test"), ByteNBT(0, name="test")], "comp").serialize() + + +def test_deseialize_compound_fail(): + """Test deserialization of NBT COMPOUND tag with invalid value.""" + # Not enough data + buffer = Buffer(bytearray([NBTagType.COMPOUND, 0x01])) + with pytest.raises(IOError): + CompoundNBT.read_from(buffer, with_name=False) + + # Not enough data + buffer = Buffer(bytearray([NBTagType.COMPOUND])) + with pytest.raises(IOError): + CompoundNBT.read_from(buffer, with_name=False) + + # Wrong tag type + buffer = Buffer(bytearray([15])) + with pytest.raises(TypeError): + NBTag.deserialize(buffer) + + +def test_nbtag_deserialize_compound(): + """Test deserialization of NBT COMPOUND tag from the NBTag class.""" + buf = Buffer(bytearray([0x00])) + assert NBTag.deserialize(buf, with_type=False, with_name=False) == CompoundNBT([]) + + buf = Buffer(bytearray.fromhex("0A 00 01 61 01 00 01 62 00 00")) + assert NBTag.deserialize(buf) == CompoundNBT([ByteNBT(0, name="b")], name="a") + + +def test_equality_compound(): + """Test equality of CompoundNBT.""" + comp1 = CompoundNBT( + [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], + "comp", + ) + comp2 = CompoundNBT( + [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], + "comp", + ) + assert comp1 == comp2 + + comp2 = CompoundNBT([ByteNBT(0, name="test"), ByteNBT(1, name="test2")], "comp") + assert comp1 != comp2 + + comp2 = CompoundNBT( + [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test4")], + "comp", + ) + assert comp1 != comp2 + + comp2 = CompoundNBT( + [ByteNBT(0, name="test"), ByteNBT(1, name="test2"), ByteNBT(2, name="test3")], + "comp2", + ) + assert comp1 != comp2 + + assert comp1 != ByteNBT(0, name="comp") + + +# endregion +# region IntArrayNBT + + +@pytest.mark.parametrize( + ("payload", "error"), + [ + ([0, "test"], ValueError), + ([0, None], ValueError), + ([0, 1 << 31], OverflowError), + ([0, -(1 << 31) - 1], OverflowError), + ], +) +def test_serialize_intarray_fail(payload: PayloadType, error: type[Exception]): + """Test serialization of NBT INTARRAY tag with invalid value.""" + with pytest.raises(error): + IntArrayNBT(payload, "test").serialize() + + +def test_deserialize_intarray_fail(): + """Test deserialization of NBT INTARRAY tag with invalid value.""" + # Not enough data for 1 element + buffer = Buffer(bytearray([0x0B, 0, 0, 0, 1, 0, 0, 0])) + with pytest.raises(IOError): + IntArrayNBT.deserialize(buffer, with_name=False) + + # Not enough data for the size + buffer = Buffer(bytearray([0x0B, 0, 0, 0])) + with pytest.raises(IOError): + IntArrayNBT.read_from(buffer, with_name=False) + + # Not enough data to start the 2nd element + buffer = Buffer(bytearray([0x0B, 0, 0, 0, 2, 1, 0, 0, 0])) + with pytest.raises(IOError): + IntArrayNBT.read_from(buffer, with_name=False) + + +# endregion +# region LongArrayNBT + + +@pytest.mark.parametrize( + ("payload", "error"), + [ + ([0, "test"], ValueError), + ([0, None], ValueError), + ([0, 1 << 63], OverflowError), + ([0, -(1 << 63) - 1], OverflowError), + ], +) +def test_serialize_deserialize_longarray_fail(payload: PayloadType, error: type[Exception]): + """Test serialization/deserialization of NBT LONGARRAY tag with invalid value.""" + with pytest.raises(error): + LongArrayNBT(payload, "test").serialize() + + +def test_deserialize_longarray_fail(): + """Test deserialization of NBT LONGARRAY tag with invalid value.""" + # Not enough data for 1 element + buffer = Buffer(bytearray([0x0C, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0])) + with pytest.raises(IOError): + LongArrayNBT.deserialize(buffer, with_name=False) + + # Not enough data for the size + buffer = Buffer(bytearray([0x0C, 0, 0, 0])) + with pytest.raises(IOError): + LongArrayNBT.read_from(buffer, with_name=False) + + # Not enough data to start the 2nd element + buffer = Buffer(bytearray([0x0C, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0])) + with pytest.raises(IOError): + LongArrayNBT.read_from(buffer, with_name=False) + + +# endregion + +# region NBTag + + +def test_nbt_helloworld(): + """Test serialization/deserialization of a simple NBT tag. + + Source data: https://wiki.vg/NBT#Example. + """ + data = bytearray.fromhex("0a000b68656c6c6f20776f726c640800046e616d65000942616e616e72616d6100") + buffer = Buffer(data) + + expected_object = { + "name": "Bananrama", + } + expected_schema = {"name": StringNBT} + + data = CompoundNBT.deserialize(buffer) + assert data == NBTag.from_object(expected_object, schema=expected_schema, name="hello world") + assert data.to_object() == expected_object + + +def test_nbt_bigfile(): + """Test serialization/deserialization of a big NBT tag. + + Slightly modified from the source data to also include a IntArrayNBT and a LongArrayNBT. + Source data: https://wiki.vg/NBT#Example. + """ + data = "0a00054c6576656c0400086c6f6e67546573747fffffffffffffff02000973686f7274546573747fff08000a737472696e6754657374002948454c4c4f20574f524c4420544849532049532041205445535420535452494e4720c385c384c39621050009666c6f6174546573743eff1832030007696e74546573747fffffff0a00146e657374656420636f6d706f756e6420746573740a000368616d0800046e616d65000648616d70757305000576616c75653f400000000a00036567670800046e616d6500074567676265727405000576616c75653f00000000000c000f6c6973745465737420286c6f6e672900000005000000000000000b000000000000000c000000000000000d000000000000000e7fffffffffffffff0b000e6c697374546573742028696e7429000000047fffffff7ffffffe7ffffffd7ffffffc0900136c697374546573742028636f6d706f756e64290a000000020800046e616d65000f436f6d706f756e642074616720233004000a637265617465642d6f6e000001265237d58d000800046e616d65000f436f6d706f756e642074616720233104000a637265617465642d6f6e000001265237d58d0001000862797465546573747f07006562797465417272617954657374202874686520666972737420313030302076616c756573206f6620286e2a6e2a3235352b6e2a3729253130302c207374617274696e672077697468206e3d302028302c2036322c2033342c2031362c20382c202e2e2e2929000003e8003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a0630003e2210080a162c4c12462004564e505c0e2e5828024a3830323e54103a0a482c1a12142036561c502a0e60585a02183862320c54423a3c485e1a44145236241c1e2a4060265a34180662000c2242083c165e4c44465204244e1e5c402e2628344a063005000a646f75626c65546573743efc000000" # noqa: E501 + data = bytes.fromhex(data) + buffer = Buffer(data) + + expected_object = { # Name ! Level + "longTest": 9223372036854775807, + "shortTest": 32767, + "stringTest": "HELLO WORLD THIS IS A TEST STRING ÅÄÖ!", + "floatTest": 0.4982314705848694, + "intTest": 2147483647, + "nested compound test": { + "ham": {"name": "Hampus", "value": 0.75}, + "egg": {"name": "Eggbert", "value": 0.5}, + }, + "listTest (long)": [11, 12, 13, 14, 9223372036854775807], + "listTest (int)": [2147483647, 2147483646, 2147483645, 2147483644], + "listTest (compound)": [ + {"name": "Compound tag #0", "created-on": 1264099775885}, + {"name": "Compound tag #1", "created-on": 1264099775885}, + ], + "byteTest": 127, + "byteArrayTest (the first 1000 values of (n*n*255+n*7)%100, " + "starting with n=0 (0, 62, 34, 16, 8, ...))": bytes((n * n * 255 + n * 7) % 100 for n in range(1000)), + "doubleTest": 0.4921875, + } + expected_schema = { + "longTest": LongNBT, + "shortTest": ShortNBT, + "stringTest": StringNBT, + "floatTest": FloatNBT, + "intTest": IntNBT, + "nested compound test": { + "ham": { + "name": StringNBT, + "value": FloatNBT, + }, + "egg": { + "name": StringNBT, + "value": FloatNBT, + }, + }, + "listTest (long)": LongArrayNBT, + "listTest (int)": IntArrayNBT, + "listTest (compound)": [ + { + "name": StringNBT, + "created-on": LongNBT, + } + ], + "byteTest": ByteNBT, + "byteArrayTest (the first 1000 values of (n*n*255+n*7)%100, " + "starting with n=0 (0, 62, 34, 16, 8, ...))": ByteArrayNBT, + "doubleTest": FloatNBT, + } + + data = CompoundNBT.deserialize(buffer) + # print(f"{data=}\n{expected_object=}\n{data.to_object()=}\n{NBTag.from_object(expected_object)=}") + + def check_equality(self: object, other: object) -> bool: + """Check if two objects are equal, with deep epsilon check for floats.""" + if type(self) != type(other): + return False + if isinstance(self, dict): + self = cast(Dict[Any, Any], self) + other = cast(Dict[Any, Any], other) + if len(self) != len(other): + return False + for key in self: + if key not in other: + return False + if not check_equality(self[key], other[key]): + return False + return True + if isinstance(self, list): + self = cast(List[Any], self) + other = cast(List[Any], other) + if len(self) != len(other): + return False + return all(check_equality(self[i], other[i]) for i in range(len(self))) + if isinstance(self, float) and isinstance(other, float): + return abs(self - other) < 1e-6 + if self != other: + return False + return self == other + + assert data == NBTag.from_object(expected_object, schema=expected_schema, name="Level") + assert check_equality(data.to_object(), expected_object) + + +# endregion +# region Edge cases + + +def test_from_object_morecases(): + """Test from_object with more edge cases.""" + + class CustomType: + def to_nbt(self, name: str = "") -> NBTag: + return ByteArrayNBT(b"CustomType", name) + + assert NBTag.from_object( + { + "number": ByteNBT(0), # ByteNBT + "bytearray": b"test", # Conversion from bytes + "empty_list": [], # Empty list with type EndNBT + "empty_compound": {}, # Empty compound + "custom": CustomType(), # Custom type with to_nbt method + "recursive_list": [ + [0, 1, 2], + [3, 4, 5], + ], + }, + { + "number": ByteNBT, + "bytearray": ByteArrayNBT, + "empty_list": [], + "empty_compound": {}, + "custom": CustomType, + "recursive_list": [[IntNBT], [ShortNBT]], + }, + ) == CompoundNBT( + [ # Order is shuffled because the spec does not require a specific order + CompoundNBT([], "empty_compound"), + ByteArrayNBT(b"test", "bytearray"), + ByteArrayNBT(b"CustomType", "custom"), + ListNBT([], "empty_list"), + ByteNBT(0, "number"), + ListNBT( + [ListNBT([IntNBT(0), IntNBT(1), IntNBT(2)]), ListNBT([ShortNBT(3), ShortNBT(4), ShortNBT(5)])], + "recursive_list", + ), + ] + ) + + compound = CompoundNBT.from_object( + {"test": 0, "test2": 0}, + {"test": ByteNBT, "test2": IntNBT}, + name="compound", + ) + + assert ListNBT([]).value == [] + assert compound.to_object(include_name=True) == {"compound": {"test": 0, "test2": 0}} + assert compound.value == {"test": 0, "test2": 0} + assert ListNBT([IntNBT(0)]).value == [0] + + assert ByteNBT(12).value == 12 + assert ShortNBT(13).value == 13 + assert IntNBT(14).value == 14 + assert LongNBT(15).value == 15 + assert FloatNBT(0.5).value == 0.5 + assert DoubleNBT(0.6).value == 0.6 + assert ByteArrayNBT(b"test").value == b"test" + assert StringNBT("test").value == "test" + assert IntArrayNBT([0, 1, 2]).value == [0, 1, 2] + assert LongArrayNBT([0, 1, 2, 3]).value == [0, 1, 2, 3] + + +@pytest.mark.parametrize( + ("data", "schema", "error", "error_msg"), + [ + # Data is not a list + ({"test": 0}, {"test": [ByteNBT]}, TypeError, "Expected a list, but found a different type."), + # Expected a list of dict, got a list of NBTags for schema + ( + {"test": [1, 0]}, + {"test": [ByteNBT, IntNBT]}, + TypeError, + "The schema must contain a single type of elements. .*", + ), + # Schema and data have different lengths + ( + [[1], [2], [3]], + [[ByteNBT], [IntNBT]], + ValueError, + "The schema and the data must have the same length.", + ), + ([1], [], ValueError, "The schema and the data must have the same length."), + # Schema is a dict, data is not + (["test"], {"test": ByteNBT}, TypeError, "Expected a dictionary, but found a different type."), + # Schema is not a dict, list or subclass of NBTagConvertible + ( + ["test"], + "test", + TypeError, + "The schema must be a list, dict, a subclass of NBTag or an object with a `to_nbt` method.", + ), + # Schema contains a mix of dict and list + ( + [{"test": 0}, [1, 2, 3]], + [{"test": ByteNBT}, [IntNBT]], + TypeError, + "Expected a list of lists or dictionaries, but found a different type", + ), + # Schema contains multiple types + ( + [[0], [-1]], + [IntArrayNBT, LongArrayNBT], + TypeError, + "The schema must contain a single type of elements.", + ), + # Schema contains CompoundNBT or ListNBT instead of a dict or list + ( + {"test": 0}, + CompoundNBT, + ValueError, + "Use a list or a dictionary in the schema to create a CompoundNBT or a ListNBT.", + ), + ( + ["test"], + ListNBT, + ValueError, + "Use a list or a dictionary in the schema to create a CompoundNBT or a ListNBT.", + ), + # The schema specifies a type, but the data is a dict with more than one key + ( + {"test": 0, "test2": 1}, + ByteNBT, + ValueError, + "Expected a dictionary with a single key-value pair.", + ), + # The data is not of the right type to be a payload + ( + {"test": object()}, + ByteNBT, + TypeError, + "Expected a bytes, str, int, float, but found object.", + ), + # The data is a list but not all elements are ints + ( + [0, "test"], + IntArrayNBT, + TypeError, + "Expected a list of integers.", + ), + ], +) +def test_from_object_error(data: Any, schema: Any, error: type[Exception], error_msg: str): + """Test from_object with erroneous data.""" + with pytest.raises(error, match=error_msg): + NBTag.from_object(data, schema) + + +def test_from_object_more_errors(): + """Test from_object with more edge cases.""" + # Redefine the name of the tag + schema = ByteNBT + data = {"test": 0} + with pytest.raises(ValueError): + NBTag.from_object(data, schema, name="othername") + + class CustomType: + def to_nbt(self, name: str = "") -> NBTag: + return ByteArrayNBT(b"CustomType", name) + + # Wrong data type + with pytest.raises(TypeError): + NBTag.from_object(0, CustomType) + + +def test_to_object_morecases(): + """Test to_object with more edge cases.""" + + class CustomType: + def to_nbt(self, name: str = "") -> NBTag: + return ByteArrayNBT(b"CustomType", name) + + assert NBTag.from_object( + { + "bytearray": b"test", + "empty_list": [], + "empty_compound": {}, + "custom": CustomType(), + "recursive_list": [ + [0, 1, 2], + [3, 4, 5], + ], + "compound_list": [{"test": 0, "test2": 1}, {"test2": 1}], + }, + { + "bytearray": ByteArrayNBT, + "empty_list": [], + "empty_compound": {}, + "custom": CustomType, + "recursive_list": [[IntNBT], [ShortNBT]], + "compound_list": [{"test": ByteNBT, "test2": IntNBT}, {"test2": IntNBT}], + }, + ).to_object(include_schema=True) == ( + { + "bytearray": b"test", + "empty_list": [], + "empty_compound": {}, + "custom": b"CustomType", + "recursive_list": [[0, 1, 2], [3, 4, 5]], + "compound_list": [{"test": 0, "test2": 1}, {"test2": 1}], + }, + { + "bytearray": ByteArrayNBT, + "empty_list": [], + "empty_compound": {}, + "custom": ByteArrayNBT, # After the conversion, the NBT tag is a ByteArrayNBT + "recursive_list": [[IntNBT], [ShortNBT]], + "compound_list": [{"test": ByteNBT, "test2": IntNBT}, {"test2": IntNBT}], + }, + ) + + assert FloatNBT(0.5).to_object() == 0.5 + assert FloatNBT(0.5, "Hello World").to_object(include_name=True) == {"Hello World": 0.5} + assert ByteArrayNBT(b"test").to_object() == b"test" # Do not add name when there is no name + assert StringNBT("test").to_object() == "test" + assert StringNBT("test", "name").to_object(include_name=True) == {"name": "test"} + assert ListNBT([ByteNBT(0), ByteNBT(1)]).to_object() == [0, 1] + assert ListNBT([ByteNBT(0), ByteNBT(1)], "name").to_object(include_name=True) == {"name": [0, 1]} + assert IntArrayNBT([0, 1, 2]).to_object() == [0, 1, 2] + assert LongArrayNBT([0, 1, 2]).to_object() == [0, 1, 2] + + with pytest.raises(TypeError): + NBTag.to_object(CompoundNBT([])) + + with pytest.raises(TypeError): + ListNBT([CompoundNBT([]), ListNBT([])]).to_object(include_schema=True) + + with pytest.raises(TypeError): + ListNBT([IntNBT(0), ShortNBT(0)]).to_object(include_schema=True) + + +def test_data_conversions(): + """Test data conversions using the built-in functions.""" + assert int(IntNBT(-1)) == -1 + assert float(FloatNBT(0.5)) == 0.5 + assert str(StringNBT("test")) == "test" + assert bytes(ByteArrayNBT(b"test")) == b"test" + assert list(ListNBT([ByteNBT(0), ByteNBT(1)])) == [ByteNBT(0), ByteNBT(1)] + assert dict(CompoundNBT([ByteNBT(0, "first"), ByteNBT(1, "second")])) == { + "first": ByteNBT(0, "first"), + "second": ByteNBT(1, "second"), + } + assert list(IntArrayNBT([0, 1, 2])) == [0, 1, 2] + assert list(LongArrayNBT([0, 1, 2])) == [0, 1, 2] + + +def test_init_nbtag_directly(): + """Test initializing NBTag directly.""" + with pytest.raises(TypeError): + NBTag(0) # type: ignore # I know, that's what I'm testing + + +@pytest.mark.parametrize( + ("buffer_content", "tag_type"), + [ + ("01", EndNBT), + ("00 00", ByteNBT), + ("01 0000", ShortNBT), + ("02 00000000", IntNBT), + ("03 0000000000000000", LongNBT), + ("04 3F800000", FloatNBT), + ("05 3FF999999999999A", DoubleNBT), + ("06 00", ByteArrayNBT), + ("07 00", StringNBT), + ("08 00", ListNBT), + ("09 00", CompoundNBT), + ("0A 00", IntArrayNBT), + ("0B 00", LongArrayNBT), + ], +) +def test_wrong_type(buffer_content: str, tag_type: type[NBTag]): + """Test read_from with wrong tag type in the buffer.""" + buffer = Buffer(bytearray.fromhex(buffer_content)) + with pytest.raises(TypeError): + tag_type.read_from(buffer, with_name=False) + + +# endregion