diff --git a/changes/285.internal.1.md b/changes/285.internal.1.md index 9c57ea2a..5536712b 100644 --- a/changes/285.internal.1.md +++ b/changes/285.internal.1.md @@ -1,57 +1,34 @@ -- Changed the way `Serializable` classes are handled: - - Here is how a basic `Serializable` class looks like: - - ```python - @final - @dataclass - class ToyClass(Serializable): - """Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`.""" - - a: int - b: str | int - - @override - def __attrs_post_init__(self): - """Initialize the object.""" - if isinstance(self.b, int): - self.b = str(self.b) - - super().__attrs_post_init__() # This will call validate() - - @override - def serialize_to(self, buf: Buffer): - """Write the object to a buffer.""" - self.b = cast(str, self.b) # Handled by the __attrs_post_init__ method - buf.write_varint(self.a) - buf.write_utf(self.b) - - @classmethod - @override - def deserialize(cls, buf: Buffer) -> ToyClass: - """Deserialize the object from a buffer.""" - a = buf.read_varint() - if a == 0: - raise ZeroDivisionError("a must be non-zero") - b = buf.read_utf() - return cls(a, b) - - @override - def validate(self) -> None: - """Validate the object's attributes.""" - if self.a == 0: - raise ZeroDivisionError("a must be non-zero") - if len(self.b) > 10: - raise ValueError("b must be less than 10 characters") - - ``` - -The `Serializable` class implement the following methods: - - - `serialize_to(buf: Buffer) -> None`: Serializes the object to a buffer. - - `deserialize(buf: Buffer) -> Serializable`: Deserializes the object from a buffer. - - And the following optional methods: - - - `validate() -> None`: Validates the object's attributes, raising an exception if they are invalid. - - `__attrs_post_init__() -> None`: Initializes the object. Call `super().__attrs_post_init__()` to validate the object. +- **Function**: `gen_serializable_test` + - Generates tests for serializable classes, covering serialization, deserialization, validation, and error handling. + - **Parameters**: + - `context` (dict): Context to add the test functions to (usually `globals()`). + - `cls` (type): The serializable class to test. + - `fields` (list): Tuples of field names and types of the serializable class. + - `serialize_deserialize` (list, optional): Tuples for testing successful serialization/deserialization. + - `validation_fail` (list, optional): Tuples for testing validation failures with expected exceptions. + - `deserialization_fail` (list, optional): Tuples for testing deserialization failures with expected exceptions. + - **Note**: Implement `__eq__` in the class for accurate comparison. + + - The `gen_serializable_test` function generates a test class with the following tests: + +.. literalinclude:: /../tests/mcproto/utils/test_serializable.py + :language: python + :start-after: # region Test ToyClass + :end-before: # endregion Test ToyClass + + - The generated test class will have the following tests: + +```python +class TestGenToyClass: + def test_serialization(self): + # 3 subtests for the cases 1, 2, 3 (serialize_deserialize) + + def test_deserialization(self): + # 3 subtests for the cases 1, 2, 3 (serialize_deserialize) + + def test_validation(self): + # 3 subtests for the cases 4, 5, 6 (validation_fail) + + def test_exceptions(self): + # 3 subtests for the cases 7, 8, 9 (deserialization_fail) +``` diff --git a/changes/285.internal.2.md b/changes/285.internal.2.md index 511bb5db..7969f33c 100644 --- a/changes/285.internal.2.md +++ b/changes/285.internal.2.md @@ -1,53 +1,16 @@ -- Added a test generator for `Serializable` classes: - - The `gen_serializable_test` function generates tests for `Serializable` classes. It takes the following arguments: - - - `context`: The dictionary containing the context in which the generated test class will be placed (e.g. `globals()`). - > Dictionary updates must reflect in the context. This is the case for `globals()` but implementation-specific for `locals()`. - - `cls`: The `Serializable` class to generate tests for. - - `fields`: A list of fields where the test values will be placed. - - > In the example above, the `ToyClass` class has two fields: `a` and `b`. - - - `test_data`: A list of tuples containing either: - - `((field1_value, field2_value, ...), expected_bytes)`: The values of the fields and the expected serialized bytes. This needs to work both ways, i.e. `cls(field1_value, field2_value, ...) == cls.deserialize(expected_bytes).` - - `((field1_value, field2_value, ...), exception)`: The values of the fields and the expected exception when validating the object. - - `(exception, bytes)`: The expected exception when deserializing the bytes and the bytes to deserialize. - - The `gen_serializable_test` function generates a test class with the following tests: - -```python -gen_serializable_test( - context=globals(), - cls=ToyClass, - fields=[("a", int), ("b", str)], - test_data=[ - ((1, "hello"), b"\x01\x05hello"), - ((2, "world"), b"\x02\x05world"), - ((3, 1234567890), b"\x03\x0a1234567890"), - ((0, "hello"), ZeroDivisionError("a must be non-zero")), # With an error message - ((1, "hello world"), ValueError), # No error message - ((1, 12345678900), ValueError("b must be less than 10 .*")), # With an error message and regex - (ZeroDivisionError, b"\x00"), - (ZeroDivisionError, b"\x01\x05hello"), - (IOError, b"\x01"), - ], -) -``` - - The generated test class will have the following tests: - - ```python -class TestGenToyClass: - def test_serialization(self): - # 2 subtests for the cases 1 and 2 - - def test_deserialization(self): - # 2 subtests for the cases 1 and 2 - - def test_validation(self): - # 2 subtests for the cases 3 and 4 - - def test_exceptions(self): - # 2 subtests for the cases 5 and 6 - ``` +- **Class**: `Serializable` + - Base class for types that should be (de)serializable into/from `mcproto.Buffer` data. + - **Methods**: + - `__attrs_post_init__()`: Runs validation after object initialization, override to define custom behavior. + - `serialize() -> Buffer`: Returns the object as a `Buffer`. + - `serialize_to(buf: Buffer)`: Abstract method to write the object to a `Buffer`. + - `validate()`: Validates the object's attributes; can be overridden for custom validation. + - `deserialize(cls, buf: Buffer) -> Self`: Abstract method to construct the object from a `Buffer`. + - **Note**: Use the `dataclass` decorator when adding parameters to subclasses. + + - Exemple: + +.. literalinclude:: /../tests/mcproto/utils/test_serializable.py + :language: python + :start-after: # region ToyClass + :end-before: # endregion ToyClass diff --git a/mcproto/utils/abc.py b/mcproto/utils/abc.py index 2afe3832..bbe78a85 100644 --- a/mcproto/utils/abc.py +++ b/mcproto/utils/abc.py @@ -68,7 +68,7 @@ def __new__(cls: type[Self], *a: Any, **kw: Any) -> Self: class Serializable(ABC): """Base class for any type that should be (de)serializable into/from :class:`~mcproto.Buffer` data. - Any class that inherits from this class and adds parameters should use the :func:`~mcproto.utils.abc.dataclass` + Any class that inherits from this class and adds parameters should use the :func:`~mcproto.utils.abc.define` decorator. """ diff --git a/tests/helpers.py b/tests/helpers.py index 039338c0..9ecae3e2 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -2,9 +2,10 @@ import asyncio import inspect +import re import unittest.mock from collections.abc import Callable, Coroutine -from typing import Any, Generic, TypeVar +from typing import Any, Generic, NamedTuple, TypeVar from typing_extensions import TypeGuard import pytest @@ -17,7 +18,14 @@ P = ParamSpec("P") T_Mock = TypeVar("T_Mock", bound=unittest.mock.Mock) -__all__ = ["synchronize", "SynchronizedMixin", "UnpropagatingMockMixin", "CustomMockMixin", "gen_serializable_test"] +__all__ = [ + "synchronize", + "SynchronizedMixin", + "UnpropagatingMockMixin", + "CustomMockMixin", + "gen_serializable_test", + "TestExc", +] def synchronize(f: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]: @@ -169,27 +177,42 @@ def __init__(self, **kwargs): super().__init__(spec_set=self.spec_set, **kwargs) # type: ignore # Mixin class, this __init__ is valid -def isexception(obj: object) -> TypeGuard[type[Exception] | Exception]: +def isexception(obj: object) -> TypeGuard[type[Exception] | TestExc]: """Check if the object is an exception.""" - return (isinstance(obj, type) and issubclass(obj, Exception)) or isinstance(obj, Exception) + return (isinstance(obj, type) and issubclass(obj, Exception)) or isinstance(obj, TestExc) -def get_exception(exception: type[Exception] | Exception) -> tuple[type[Exception], str | None]: - """Get the exception type and message.""" - if isinstance(exception, type): - return exception, None - return type(exception), str(exception) +class TestExc(NamedTuple): + """Named tuple to check if an exception is raised with a specific message. + + :param exception: The exception type. + :param match: If specified, a string containing a regular expression, or a regular expression object, that is + tested against the string representation of the exception using :func:`re.search`. + + :param kwargs: The keyword arguments passed to the exception. + + If :attr:`kwargs` is not None, the exception instance will need to have the same attributes with the same values. + """ + + exception: type[Exception] | tuple[type[Exception], ...] + match: str | re.Pattern[str] | None = None + kwargs: dict[str, Any] | None = None + + @classmethod + def from_exception(cls, exception: type[Exception] | tuple[type[Exception], ...] | TestExc) -> TestExc: + """Create a :class:`TestExc` from an exception, does nothing if the object is already a :class:`TestExc`.""" + if isinstance(exception, TestExc): + return exception + return cls(exception) def gen_serializable_test( context: dict[str, Any], cls: type[Serializable], fields: list[tuple[str, type | str]], - test_data: list[ - tuple[tuple[Any, ...], bytes] - | tuple[tuple[Any, ...], type[Exception] | Exception] - | tuple[type[Exception] | Exception, bytes] - ], + serialize_deserialize: list[tuple[tuple[Any, ...], bytes]] | None = None, + validation_fail: list[tuple[tuple[Any, ...], type[Exception] | TestExc]] | None = None, + deserialization_fail: list[tuple[bytes, type[Exception] | TestExc]] | None = None, ): """Generate tests for a serializable class. @@ -199,15 +222,14 @@ def gen_serializable_test( :param context: The context to add the test functions to. This is usually `globals()`. :param cls: The serializable class to test. :param fields: A list of tuples containing the field names and types of the serializable class. - :param test_data: A list of test data. Each element is a tuple containing either: - - A tuple of parameters to pass to the serializable class constructor and the expected bytes after - serialization - - A tuple of parameters to pass to the serializable class constructor and the expected exception during - validation - - An exception to expect during deserialization and the bytes to deserialize - - Exception can be either a type or an instance of an exception, in the latter case the exception message will - be used to match the exception, and can contain regex patterns. + :param serialize_deserialize: A list of tuples containing: + - The tuple representing the arguments to pass to the :class:`mcproto.utils.abc.Serializable` class + - The expected bytes + :param validation_fail: A list of tuples containing the arguments to pass to the + :class:`mcproto.utils.abc.Serializable` class and the expected exception, either as is or wrapped in a + :class:`TestExc` object. + :param deserialization_fail: A list of tuples containing the bytes to pass to the :meth:`deserialize` method of the + class and the expected exception, either as is or wrapped in a :class:`TestExc` object. Example usage: @@ -221,28 +243,30 @@ def gen_serializable_test( .. note:: The test cases will use :meth:`__eq__` to compare the objects, so make sure to implement it in the class if - you are not using a dataclass. + you are not using the autogenerated method from :func:`attrs.define`. """ - # Separate the test data into parameters for each test function # This holds the parameters for the serialization and deserialization tests parameters: list[tuple[dict[str, Any], bytes]] = [] # This holds the parameters for the validation tests - validation_fail: list[tuple[dict[str, Any], type[Exception] | Exception]] = [] + validation_fail_kw: list[tuple[dict[str, Any], TestExc]] = [] + + for data, exp_bytes in [] if serialize_deserialize is None else serialize_deserialize: + kwargs = dict(zip([f[0] for f in fields], data)) + parameters.append((kwargs, exp_bytes)) - # This holds the parameters for the deserialization error tests - deserialization_fail: list[tuple[bytes, type[Exception] | Exception]] = [] + for data, exc in [] if validation_fail is None else validation_fail: + kwargs = dict(zip([f[0] for f in fields], data)) + exc_wrapped = TestExc.from_exception(exc) + validation_fail_kw.append((kwargs, exc_wrapped)) - for data_or_exc, expected_bytes_or_exc in test_data: - if isinstance(data_or_exc, tuple) and isinstance(expected_bytes_or_exc, bytes): - kwargs = dict(zip([f[0] for f in fields], data_or_exc)) - parameters.append((kwargs, expected_bytes_or_exc)) - elif isexception(data_or_exc) and isinstance(expected_bytes_or_exc, bytes): - deserialization_fail.append((expected_bytes_or_exc, data_or_exc)) - elif isinstance(data_or_exc, tuple) and isexception(expected_bytes_or_exc): - kwargs = dict(zip([f[0] for f in fields], data_or_exc)) - validation_fail.append((kwargs, expected_bytes_or_exc)) + # Just make sure that the exceptions are wrapped in TestExc + deserialization_fail = ( + [] + if deserialization_fail is None + else [(data, TestExc.from_exception(exc)) for data, exc in deserialization_fail] + ) def generate_name(param: dict[str, Any] | bytes, i: int) -> str: """Generate a name for the test case.""" @@ -301,33 +325,45 @@ def test_deserialization(self, kwargs: dict[str, Any], expected_bytes: bytes): @pytest.mark.parametrize( ("kwargs", "exc"), - validation_fail, - ids=tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(validation_fail)), + validation_fail_kw, + ids=tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(validation_fail_kw)), ) - def test_validation(self, kwargs: dict[str, Any], exc: type[Exception] | Exception): + def test_validation(self, kwargs: dict[str, Any], exc: TestExc): """Test validation of the object.""" - exc, msg = get_exception(exc) - with pytest.raises(exc, match=msg): + with pytest.raises(exc.exception, match=exc.match) as exc_info: cls(**kwargs) + # If exc.kwargs is not None, check them against the exception + if exc.kwargs is not None: + for key, value in exc.kwargs.items(): + assert value == getattr( + exc_info.value, key + ), f"{key}: {value!r} != {getattr(exc_info.value, key)!r}" + @pytest.mark.parametrize( ("content", "exc"), deserialization_fail, ids=tuple(generate_name(content, i) for i, (content, _) in enumerate(deserialization_fail)), ) - def test_deserialization_error(self, content: bytes, exc: type[Exception] | Exception): + def test_deserialization_error(self, content: bytes, exc: TestExc): """Test deserialization error handling.""" buf = Buffer(content) - exc, msg = get_exception(exc) - with pytest.raises(exc, match=msg): + with pytest.raises(exc.exception, match=exc.match) as exc_info: cls.deserialize(buf) + # If exc.kwargs is not None, check them against the exception + if exc.kwargs is not None: + for key, value in exc.kwargs.items(): + assert value == getattr( + exc_info.value, key + ), f"{key}: {value!r} != {getattr(exc_info.value, key)!r}" + if len(parameters) == 0: # If there are no serialization tests, remove them del TestClass.test_serialization del TestClass.test_deserialization - if len(validation_fail) == 0: + if len(validation_fail_kw) == 0: # If there are no validation tests, remove them del TestClass.test_validation diff --git a/tests/mcproto/packets/handshaking/test_handshake.py b/tests/mcproto/packets/handshaking/test_handshake.py index 53c5de4a..6fe0d525 100644 --- a/tests/mcproto/packets/handshaking/test_handshake.py +++ b/tests/mcproto/packets/handshaking/test_handshake.py @@ -12,7 +12,7 @@ ("server_port", int), ("next_state", NextState), ], - test_data=[ + serialize_deserialize=[ ( (757, "mc.aircs.racing", 25565, NextState.LOGIN), bytes.fromhex("f5050f6d632e61697263732e726163696e6763dd02"), @@ -29,6 +29,8 @@ (757, "hypixel.net", 25565, NextState.STATUS), bytes.fromhex("f5050b6879706978656c2e6e657463dd01"), ), + ], + validation_fail=[ # Invalid next state ((757, "localhost", 25565, 3), ValueError), ], diff --git a/tests/mcproto/packets/login/test_login.py b/tests/mcproto/packets/login/test_login.py index 71067022..33ad8095 100644 --- a/tests/mcproto/packets/login/test_login.py +++ b/tests/mcproto/packets/login/test_login.py @@ -21,7 +21,7 @@ context=globals(), cls=LoginStart, fields=[("username", str), ("uuid", UUID)], - test_data=[ + serialize_deserialize=[ ( ("ItsDrike", UUID("f70b4a42c9a04ffb92a31390c128a1b2")), bytes.fromhex("084974734472696b65f70b4a42c9a04ffb92a31390c128a1b2"), @@ -38,7 +38,7 @@ context=globals(), cls=LoginEncryptionRequest, fields=[("server_id", str), ("public_key", bytes), ("verify_token", bytes)], - test_data=[ + serialize_deserialize=[ ( ("a" * 20, RSA_PUBLIC_KEY, bytes.fromhex("9bd416ef")), bytes.fromhex( @@ -48,7 +48,9 @@ "4c3938a298da575e12e0ae178d61a69bc0ea0b381790f182d9dba715bfb503c99d92b0203010001049bd416ef" ), ), - (InvalidPacketContentError, bytes.fromhex("14")), + ], + deserialization_fail=[ + (bytes.fromhex("14"), InvalidPacketContentError), ], ) @@ -64,7 +66,7 @@ def test_login_encryption_request_noid(): context=globals(), cls=LoginEncryptionResponse, fields=[("shared_secret", bytes), ("verify_token", bytes)], - test_data=[ + serialize_deserialize=[ ( (b"I'm shared", b"Token"), bytes.fromhex("0a49276d2073686172656405546f6b656e"), @@ -78,7 +80,7 @@ def test_login_encryption_request_noid(): context=globals(), cls=LoginSuccess, fields=[("uuid", UUID), ("username", str)], - test_data=[ + serialize_deserialize=[ ( (UUID("f70b4a42c9a04ffb92a31390c128a1b2"), "Mario"), bytes.fromhex("f70b4a42c9a04ffb92a31390c128a1b2054d6172696f"), @@ -91,7 +93,7 @@ def test_login_encryption_request_noid(): context=globals(), cls=LoginDisconnect, fields=[("reason", ChatMessage)], - test_data=[ + serialize_deserialize=[ ( (ChatMessage("You are banned."),), bytes.fromhex("1122596f75206172652062616e6e65642e22"), @@ -105,7 +107,7 @@ def test_login_encryption_request_noid(): context=globals(), cls=LoginPluginRequest, fields=[("message_id", int), ("channel", str), ("data", bytes)], - test_data=[ + serialize_deserialize=[ ( (0, "xyz", b"Hello"), bytes.fromhex("000378797a48656c6c6f"), @@ -119,7 +121,7 @@ def test_login_encryption_request_noid(): context=globals(), cls=LoginPluginResponse, fields=[("message_id", int), ("data", bytes)], - test_data=[ + serialize_deserialize=[ ( (0, b"Hello"), bytes.fromhex("000148656c6c6f"), @@ -132,7 +134,7 @@ def test_login_encryption_request_noid(): context=globals(), cls=LoginSetCompression, fields=[("threshold", int)], - test_data=[ + serialize_deserialize=[ ( (2,), bytes.fromhex("02"), diff --git a/tests/mcproto/packets/status/test_ping.py b/tests/mcproto/packets/status/test_ping.py index 245a03aa..a0b753ef 100644 --- a/tests/mcproto/packets/status/test_ping.py +++ b/tests/mcproto/packets/status/test_ping.py @@ -7,7 +7,7 @@ context=globals(), cls=PingPong, fields=[("payload", int)], - test_data=[ + serialize_deserialize=[ ( (2806088,), bytes.fromhex("00000000002ad148"), diff --git a/tests/mcproto/packets/status/test_status.py b/tests/mcproto/packets/status/test_status.py index 19ce720c..8bc377d4 100644 --- a/tests/mcproto/packets/status/test_status.py +++ b/tests/mcproto/packets/status/test_status.py @@ -8,7 +8,7 @@ context=globals(), cls=StatusResponse, fields=[("data", "dict[str, Any]")], - test_data=[ + serialize_deserialize=[ ( ( { @@ -24,6 +24,8 @@ "16d65223a22312e31382e31222c2270726f746f636f6c223a3735377d7d" ), ), + ], + validation_fail=[ # Unserializable data for JSON (({"data": object()},), ValueError), ], diff --git a/tests/mcproto/types/test_chat.py b/tests/mcproto/types/test_chat.py index 3c8f05a6..54f16583 100644 --- a/tests/mcproto/types/test_chat.py +++ b/tests/mcproto/types/test_chat.py @@ -53,7 +53,7 @@ def test_equality(raw1: RawChatMessage, raw2: RawChatMessage, expected_result: b context=globals(), cls=ChatMessage, fields=[("raw", RawChatMessage)], - test_data=[ + serialize_deserialize=[ ( ("A Minecraft Server",), bytes.fromhex("142241204d696e6563726166742053657276657222"), @@ -66,6 +66,8 @@ def test_equality(raw1: RawChatMessage, raw2: RawChatMessage, expected_result: b ([{"text": "abc"}, {"text": "def"}],), bytes.fromhex("225b7b2274657874223a2022616263227d2c207b2274657874223a2022646566227d5d"), ), + ], + validation_fail=[ # Wrong type for raw ((b"invalid",), TypeError), (({"no_extra_or_text": "invalid"},), AttributeError), diff --git a/tests/mcproto/types/test_nbt.py b/tests/mcproto/types/test_nbt.py index 8211e7b8..10bacbb1 100644 --- a/tests/mcproto/types/test_nbt.py +++ b/tests/mcproto/types/test_nbt.py @@ -30,9 +30,11 @@ context=globals(), cls=EndNBT, fields=[], - test_data=[ + serialize_deserialize=[ ((), b"\x00"), - (IOError, b"\x01"), + ], + deserialization_fail=[ + (b"\x01", IOError), ], ) @@ -44,20 +46,24 @@ context=globals(), cls=ByteNBT, fields=[("payload", int), ("name", str)], - test_data=[ + serialize_deserialize=[ ((0, "a"), b"\x01\x00\x01a\x00"), ((1, "test"), b"\x01\x00\x04test\x01"), ((127, "&à@é"), b"\x01\x00\x06" + bytes("&à@é", "utf-8") + b"\x7f"), ((-128, "test"), b"\x01\x00\x04test\x80"), ((-1, "a" * 100), b"\x01\x00\x64" + b"a" * 100 + b"\xff"), + ], + deserialization_fail=[ # Errors - (IOError, b"\x01\x00\x04test"), - (IOError, b"\x01\x00\x04tes"), - (IOError, b"\x01\x00"), - (IOError, b"\x01"), + (b"\x01\x00\x04test", IOError), + (b"\x01\x00\x04tes", IOError), + (b"\x01\x00", IOError), + (b"\x01", IOError), # Wrong type - (TypeError, b"\x02\x00\x01a\x00"), - (TypeError, b"\xff\x00\x01a\x00"), + (b"\x02\x00\x01a\x00", TypeError), + (b"\xff\x00\x01a\x00", TypeError), + ], + validation_fail=[ # Out of bounds ((1 << 7, "a"), OverflowError), ((-(1 << 7) - 1, "a"), OverflowError), @@ -72,17 +78,21 @@ context=globals(), cls=ShortNBT, fields=[("payload", int), ("name", str)], - test_data=[ + serialize_deserialize=[ ((0, "a"), b"\x02\x00\x01a\x00\x00"), ((1, "test"), b"\x02\x00\x04test\x00\x01"), ((32767, "&à@é"), b"\x02\x00\x06" + bytes("&à@é", "utf-8") + b"\x7f\xff"), ((-32768, "test"), b"\x02\x00\x04test\x80\x00"), ((-1, "a" * 100), b"\x02\x00\x64" + b"a" * 100 + b"\xff\xff"), + ], + deserialization_fail=[ # Errors - (IOError, b"\x02\x00\x04test"), - (IOError, b"\x02\x00\x04tes"), - (IOError, b"\x02\x00"), - (IOError, b"\x02"), + (b"\x02\x00\x04test", IOError), + (b"\x02\x00\x04tes", IOError), + (b"\x02\x00", IOError), + (b"\x02", IOError), + ], + validation_fail=[ # Out of bounds ((1 << 15, "a"), OverflowError), ((-(1 << 15) - 1, "a"), OverflowError), @@ -96,17 +106,21 @@ context=globals(), cls=IntNBT, fields=[("payload", int), ("name", str)], - test_data=[ + serialize_deserialize=[ ((0, "a"), b"\x03\x00\x01a\x00\x00\x00\x00"), ((1, "test"), b"\x03\x00\x04test\x00\x00\x00\x01"), ((2147483647, "&à@é"), b"\x03\x00\x06" + bytes("&à@é", "utf-8") + b"\x7f\xff\xff\xff"), ((-2147483648, "test"), b"\x03\x00\x04test\x80\x00\x00\x00"), ((-1, "a" * 100), b"\x03\x00\x64" + b"a" * 100 + b"\xff\xff\xff\xff"), + ], + deserialization_fail=[ # Errors - (IOError, b"\x03\x00\x04test"), - (IOError, b"\x03\x00\x04tes"), - (IOError, b"\x03\x00"), - (IOError, b"\x03"), + (b"\x03\x00\x04test", IOError), + (b"\x03\x00\x04tes", IOError), + (b"\x03\x00", IOError), + (b"\x03", IOError), + ], + validation_fail=[ # Out of bounds ((1 << 31, "a"), OverflowError), ((-(1 << 31) - 1, "a"), OverflowError), @@ -120,17 +134,21 @@ context=globals(), cls=LongNBT, fields=[("payload", int), ("name", str)], - test_data=[ + serialize_deserialize=[ ((0, "a"), b"\x04\x00\x01a\x00\x00\x00\x00\x00\x00\x00\x00"), ((1, "test"), b"\x04\x00\x04test\x00\x00\x00\x00\x00\x00\x00\x01"), (((1 << 63) - 1, "&à@é"), b"\x04\x00\x06" + bytes("&à@é", "utf-8") + b"\x7f\xff\xff\xff\xff\xff\xff\xff"), ((-1 << 63, "test"), b"\x04\x00\x04test\x80\x00\x00\x00\x00\x00\x00\x00"), ((-1, "a" * 100), b"\x04\x00\x64" + b"a" * 100 + b"\xff\xff\xff\xff\xff\xff\xff\xff"), + ], + deserialization_fail=[ # Errors - (IOError, b"\x04\x00\x04test"), - (IOError, b"\x04\x00\x04tes"), - (IOError, b"\x04\x00"), - (IOError, b"\x04"), + (b"\x04\x00\x04test", IOError), + (b"\x04\x00\x04tes", IOError), + (b"\x04\x00", IOError), + (b"\x04", IOError), + ], + validation_fail=[ # Out of bounds ((1 << 63, "a"), OverflowError), ((-(1 << 63) - 1, "a"), OverflowError), @@ -139,23 +157,28 @@ ], ) + # endregion # region Floating point NBT tests gen_serializable_test( context=globals(), cls=FloatNBT, fields=[("payload", float), ("name", str)], - test_data=[ + serialize_deserialize=[ ((1.0, "a"), b"\x05\x00\x01a" + bytes(struct.pack(">f", 1.0))), ((0.5, "test"), b"\x05\x00\x04test" + bytes(struct.pack(">f", 0.5))), # has to be convertible to float exactly ((-1.0, "&à@é"), b"\x05\x00\x06" + bytes("&à@é", "utf-8") + bytes(struct.pack(">f", -1.0))), ((12.0, "a" * 100), b"\x05\x00\x64" + b"a" * 100 + bytes(struct.pack(">f", 12.0))), ((1, "a"), b"\x05\x00\x01a" + bytes(struct.pack(">f", 1.0))), + ], + deserialization_fail=[ # Errors - (IOError, b"\x05\x00\x04test"), - (IOError, b"\x05\x00\x04tes"), - (IOError, b"\x05\x00"), - (IOError, b"\x05"), + (b"\x05\x00\x04test", IOError), + (b"\x05\x00\x04tes", IOError), + (b"\x05\x00", IOError), + (b"\x05", IOError), + ], + validation_fail=[ # Wrong type (("1.5", "a"), TypeError), ], @@ -165,26 +188,29 @@ context=globals(), cls=DoubleNBT, fields=[("payload", float), ("name", str)], - test_data=[ + serialize_deserialize=[ ((1.0, "a"), b"\x06\x00\x01a" + bytes(struct.pack(">d", 1.0))), ((3.14, "test"), b"\x06\x00\x04test" + bytes(struct.pack(">d", 3.14))), ((-1.0, "&à@é"), b"\x06\x00\x06" + bytes("&à@é", "utf-8") + bytes(struct.pack(">d", -1.0))), ((12.0, "a" * 100), b"\x06\x00\x64" + b"a" * 100 + bytes(struct.pack(">d", 12.0))), + ], + deserialization_fail=[ # Errors - (IOError, b"\x06\x00\x04test\x01"), - (IOError, b"\x06\x00\x04test"), - (IOError, b"\x06\x00\x04tes"), - (IOError, b"\x06\x00"), - (IOError, b"\x06"), + (b"\x06\x00\x04test\x01", IOError), + (b"\x06\x00\x04test", IOError), + (b"\x06\x00\x04tes", IOError), + (b"\x06\x00", IOError), + (b"\x06", IOError), ], ) + # endregion # region Variable Length NBT tests gen_serializable_test( context=globals(), cls=ByteArrayNBT, fields=[("payload", bytes), ("name", str)], - test_data=[ + serialize_deserialize=[ ((b"", "a"), b"\x07\x00\x01a\x00\x00\x00\x00"), ((b"\x00", "test"), b"\x07\x00\x04test\x00\x00\x00\x01\x00"), ((b"\x00\x01", "&à@é"), b"\x07\x00\x06" + bytes("&à@é", "utf-8") + b"\x00\x00\x00\x02\x00\x01"), @@ -192,15 +218,19 @@ ((b"\xff" * 1024, "a" * 100), b"\x07\x00\x64" + b"a" * 100 + b"\x00\x00\x04\x00" + b"\xff" * 1024), ((b"Hello World", "test"), b"\x07\x00\x04test\x00\x00\x00\x0b" + b"Hello World"), ((bytearray(b"Hello World"), "test"), b"\x07\x00\x04test\x00\x00\x00\x0b" + b"Hello World"), + ], + deserialization_fail=[ # Errors - (IOError, b"\x07\x00\x04test"), - (IOError, b"\x07\x00\x04tes"), - (IOError, b"\x07\x00"), - (IOError, b"\x07"), - (IOError, b"\x07\x00\x01a\x00\x01"), - (IOError, b"\x07\x00\x01a\x00\x00\x00\xff"), + (b"\x07\x00\x04test", IOError), + (b"\x07\x00\x04tes", IOError), + (b"\x07\x00", IOError), + (b"\x07", IOError), + (b"\x07\x00\x01a\x00\x01", IOError), + (b"\x07\x00\x01a\x00\x00\x00\xff", IOError), # Negative length - (ValueError, b"\x07\x00\x01a\xff\xff\xff\xff"), + (b"\x07\x00\x01a\xff\xff\xff\xff", ValueError), + ], + validation_fail=[ # Wrong type ((1, "a"), TypeError), ], @@ -210,20 +240,24 @@ context=globals(), cls=StringNBT, fields=[("payload", str), ("name", str)], - test_data=[ + serialize_deserialize=[ (("", "a"), b"\x08\x00\x01a\x00\x00"), (("test", "a"), b"\x08\x00\x01a\x00\x04" + b"test"), (("a" * 100, "&à@é"), b"\x08\x00\x06" + bytes("&à@é", "utf-8") + b"\x00\x64" + b"a" * 100), (("&à@é", "test"), b"\x08\x00\x04test\x00\x06" + bytes("&à@é", "utf-8")), + ], + deserialization_fail=[ # Errors - (IOError, b"\x08\x00\x04test"), - (IOError, b"\x08\x00\x04tes"), - (IOError, b"\x08\x00"), - (IOError, b"\x08"), - # Negative length - (ValueError, b"\x08\xff\xff\xff\xff"), + (b"\x08\x00\x04test", IOError), + (b"\x08\x00\x04tes", IOError), + (b"\x08\x00", IOError), + (b"\x08", IOError), # Unicode decode error - (UnicodeDecodeError, b"\x08\x00\x01a\x00\x01\xff"), + (b"\x08\x00\x01a\x00\x01\xff", UnicodeDecodeError), + (b"\x08\xff\xff\xff\xff", ValueError), + ], + validation_fail=[ + # Negative length # String too long (("a" * 32768, "b"), ValueError), # Wrong type @@ -237,8 +271,7 @@ context=globals(), cls=ListNBT, fields=[("payload", list), ("name", str)], - test_data=[ - # Here we only want to test ListNBT related stuff + serialize_deserialize=[ (([], "a"), b"\x09\x00\x01a\x00\x00\x00\x00\x00"), (([ByteNBT(-1)], "a"), b"\x09\x00\x01a\x01\x00\x00\x00\x01\xff"), (([ListNBT([])], "a"), b"\x09\x00\x01a\x09\x00\x00\x00\x01" + ListNBT([]).serialize()[1:]), @@ -255,15 +288,18 @@ + ListNBT([ByteNBT(-1)]).serialize()[1:] + ListNBT([IntNBT(128), IntNBT(8)]).serialize()[1:], ), - # Errors + ], + deserialization_fail=[ # Not enough data - (IOError, b"\x09\x00\x01a"), - (IOError, b"\x09\x00\x01a\x01"), - (IOError, b"\x09\x00\x01a\x01\x00"), - (IOError, b"\x09\x00\x01a\x01\x00\x00\x00\x01"), - (IOError, b"\x09\x00\x01a\x01\x00\x00\x00\x03\x01"), + (b"\x09\x00\x01a", IOError), + (b"\x09\x00\x01a\x01", IOError), + (b"\x09\x00\x01a\x01\x00", IOError), + (b"\x09\x00\x01a\x01\x00\x00\x00\x01", IOError), + (b"\x09\x00\x01a\x01\x00\x00\x00\x03\x01", IOError), # Invalid tag type - (TypeError, b"\x09\x00\x01a\xff\x00\x00\x01\x00"), + (b"\x09\x00\x01a\xff\x00\x00\x01\x00", TypeError), + ], + validation_fail=[ # Not NBTags (([1, 2, 3], "a"), TypeError), # Not the same tag type @@ -279,7 +315,7 @@ context=globals(), cls=CompoundNBT, fields=[("payload", list), ("name", str)], - test_data=[ + serialize_deserialize=[ (([], "a"), b"\x0a\x00\x01a\x00"), (([ByteNBT(0, name="Byte")], "a"), b"\x0a\x00\x01a" + ByteNBT(0, name="Byte").serialize() + b"\x00"), ( @@ -294,15 +330,18 @@ ([ListNBT([ByteNBT(0)] * 3, name="List")], "a"), b"\x0a\x00\x01a" + ListNBT([ByteNBT(0)] * 3, name="List").serialize() + b"\x00", ), - # Errors + ], + deserialization_fail=[ # Not enough data - (IOError, b"\x0a\x00\x01a"), - (IOError, b"\x0a\x00\x01a\x01"), - # All muse be NBTags + (b"\x0a\x00\x01a", IOError), + (b"\x0a\x00\x01a\x01", IOError), + ], + validation_fail=[ + # All must be NBTags (([0, 1, 2], "a"), TypeError), # All with a name (([ByteNBT(0)], "a"), ValueError), - # Must be unique + # Names must be unique (([ByteNBT(0, name="Byte"), ByteNBT(0, name="Byte")], "a"), ValueError), # Wrong type ((1, "a"), TypeError), @@ -313,19 +352,22 @@ context=globals(), cls=IntArrayNBT, fields=[("payload", list), ("name", str)], - test_data=[ + serialize_deserialize=[ (([], "a"), b"\x0b\x00\x01a\x00\x00\x00\x00"), (([0], "a"), b"\x0b\x00\x01a\x00\x00\x00\x01\x00\x00\x00\x00"), (([0, 1], "a"), b"\x0b\x00\x01a\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01"), (([1, 2, 3], "a"), b"\x0b\x00\x01a\x00\x00\x00\x03\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03"), (([(1 << 31) - 1], "a"), b"\x0b\x00\x01a\x00\x00\x00\x01\x7f\xff\xff\xff"), (([-1, -2, -3], "a"), b"\x0b\x00\x01a\x00\x00\x00\x03\xff\xff\xff\xff\xff\xff\xff\xfe\xff\xff\xff\xfd"), - # Errors + ], + deserialization_fail=[ # Not enough data - (IOError, b"\x0b\x00\x01a"), - (IOError, b"\x0b\x00\x01a\x01"), - (IOError, b"\x0b\x00\x01a\x00\x00\x00\x01"), - (IOError, b"\x0b\x00\x01a\x00\x00\x00\x03\x01"), + (b"\x0b\x00\x01a", IOError), + (b"\x0b\x00\x01a\x01", IOError), + (b"\x0b\x00\x01a\x00\x00\x00\x01", IOError), + (b"\x0b\x00\x01a\x00\x00\x00\x03\x01", IOError), + ], + validation_fail=[ # Must contain ints only ((["a"], "a"), TypeError), (([IntNBT(0)], "a"), TypeError), @@ -335,11 +377,12 @@ ((1, "a"), TypeError), ], ) + gen_serializable_test( context=globals(), cls=LongArrayNBT, fields=[("payload", list), ("name", str)], - test_data=[ + serialize_deserialize=[ (([], "a"), b"\x0c\x00\x01a\x00\x00\x00\x00"), (([0], "a"), b"\x0c\x00\x01a\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00"), ( @@ -351,12 +394,16 @@ ([-1, -2], "a"), b"\x0c\x00\x01a\x00\x00\x00\x02\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xfe", ), + ], + deserialization_fail=[ # Not enough data - (IOError, b"\x0c\x00\x01a"), - (IOError, b"\x0c\x00\x01a\x01"), - (IOError, b"\x0c\x00\x01a\x00\x00\x00\x01"), - (IOError, b"\x0c\x00\x01a\x00\x00\x00\x03\x01"), - # Must contain ints only + (b"\x0c\x00\x01a", IOError), + (b"\x0c\x00\x01a\x01", IOError), + (b"\x0c\x00\x01a\x00\x00\x00\x01", IOError), + (b"\x0c\x00\x01a\x00\x00\x00\x03\x01", IOError), + ], + validation_fail=[ + # Must contain longs only ((["a"], "a"), TypeError), (([LongNBT(0)], "a"), TypeError), (([1 << 63], "a"), OverflowError), diff --git a/tests/mcproto/types/test_uuid.py b/tests/mcproto/types/test_uuid.py index ffdd7e7e..a34fa7da 100644 --- a/tests/mcproto/types/test_uuid.py +++ b/tests/mcproto/types/test_uuid.py @@ -7,14 +7,18 @@ context=globals(), cls=UUID, fields=[("hex", str)], - test_data=[ + serialize_deserialize=[ (("12345678-1234-5678-1234-567812345678",), bytes.fromhex("12345678123456781234567812345678")), + ], + validation_fail=[ # Too short or too long (("12345678-1234-5678-1234-56781234567",), ValueError), (("12345678-1234-5678-1234-5678123456789",), ValueError), + ], + deserialization_fail=[ # Not enough data in the buffer (needs 16 bytes) - (IOError, b""), - (IOError, b"\x01"), - (IOError, b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e"), + (b"", IOError), + (b"\x01", IOError), + (b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e", IOError), ], ) diff --git a/tests/mcproto/utils/test_serializable.py b/tests/mcproto/utils/test_serializable.py index 2638e4cd..62381808 100644 --- a/tests/mcproto/utils/test_serializable.py +++ b/tests/mcproto/utils/test_serializable.py @@ -1,11 +1,21 @@ from __future__ import annotations -from typing import cast, final +from typing import Any, cast, final from typing_extensions import override from mcproto.buffer import Buffer from mcproto.utils.abc import Serializable, define -from tests.helpers import gen_serializable_test +from tests.helpers import gen_serializable_test, TestExc + + +class CustomError(Exception): + """Custom exception for testing.""" + + additional_data: Any + + def __init__(self, message: str, additional_data: Any): + super().__init__(message) + self.additional_data = additional_data # region ToyClass @@ -37,7 +47,7 @@ def deserialize(cls, buf: Buffer) -> ToyClass: """Deserialize the object from a buffer.""" a = buf.read_varint() if a == 0: - raise ZeroDivisionError("a must be non-zero") + raise CustomError("a must be non-zero", additional_data=a) b = buf.read_utf() return cls(a, b) @@ -51,26 +61,30 @@ def validate(self) -> None: raise ValueError("b must be less than 10 characters") -# endregion +# endregion ToyClass # region Test ToyClass gen_serializable_test( context=globals(), cls=ToyClass, fields=[("a", int), ("b", str)], - test_data=[ + serialize_deserialize=[ ((1, "hello"), b"\x01\x05hello"), ((2, "world"), b"\x02\x05world"), ((3, 1234567890), b"\x03\x0a1234567890"), - ((0, "hello"), ZeroDivisionError("a must be non-zero")), # Message specified + ], + validation_fail=[ + ((0, "hello"), TestExc(ZeroDivisionError, "a must be non-zero")), # Message specified ((1, "hello world"), ValueError), # No message specified - ((1, 12345678900), ValueError("b must be less than .*")), # Message specified with regex - (ZeroDivisionError, b"\x00"), - (ZeroDivisionError, b"\x00\x05hello"), - (IOError, b"\x01"), + ((1, 12345678900), TestExc(ValueError, "b must be less than .*")), # Message regex + ], + deserialization_fail=[ + (b"\x00", CustomError), # No message specified + (b"\x00\x05hello", TestExc(CustomError, "a must be non-zero", {"additional_data": 0})), # Check fields + (b"\x01", TestExc(IOError)), # No message specified ], ) -# endregion +# endregion Test ToyClass if __name__ == "__main__":