Skip to content

Commit

Permalink
Change the way Serializable classes work
Browse files Browse the repository at this point in the history
Provide a way to test serialization, deserialization, validation and deserialization errors easily.
This fixes Avoid repetition in tests for serialize+deserialize tests #64 and makes it easier to add new data types
  • Loading branch information
LiteApplication committed May 1, 2024
1 parent 062713c commit 09ff7b7
Show file tree
Hide file tree
Showing 18 changed files with 799 additions and 729 deletions.
90 changes: 90 additions & 0 deletions changes/273.internal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
- Changed the way `Serializable` classes are handled:

Here is how a basic `Serializable` class looks like:

@final
@dataclass
class ToyClass(Serializable):
"""Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`."""


# Attributes can be of any type
a: int
b: str

# dataclasses.field() can be used to specify additional metadata

def serialize_to(self, buf: Buffer):
"""Write the object to a buffer."""
buf.write_varint(self.a)
buf.write_utf(self.b)

@classmethod
def deserialize(cls, buf: Buffer) -> ToyClass:
"""Deserialize the object from a buffer."""
a = buf.read_varint()
if a == 0:
raise ZeroDivisionError("a must be non-zero")
b = buf.read_utf()
return cls(a, b)

def validate(self) -> None:
"""Validate the object's attributes."""
if self.a == 0:
raise ZeroDivisionError("a must be non-zero")
if len(self.b) > 10:
raise ValueError("b must be less than 10 characters")


The `Serializable` class must implement the following methods:

- `serialize_to(buf: Buffer) -> None`: Serializes the object to a buffer.
- `deserialize(buf: Buffer) -> Serializable`: Deserializes the object from a buffer.
- `validate() -> None`: Validates the object's attributes, raising an exception if they are invalid.

- Added a test generator for `Serializable` classes:

The `gen_serializable_test` function generates tests for `Serializable` classes. It takes the following arguments:

- `context`: The dictionary containing the context in which the generated test class will be placed (e.g. `globals()`).
> Dictionary updates must reflect in the context. This is the case for `globals()` but implementation-specific for `locals()`.
- `cls`: The `Serializable` class to generate tests for.
- `fields`: A list of fields where the test values will be placed.

> In the example above, the `ToyClass` class has two fields: `a` and `b`.
- `test_data`: A list of tuples containing either:
- `((field1_value, field2_value, ...), expected_bytes)`: The values of the fields and the expected serialized bytes. This needs to work both ways, i.e. `cls(field1_value, field2_value, ...) == cls.deserialize(expected_bytes).`
- `((field1_value, field2_value, ...), exception)`: The values of the fields and the expected exception when validating the object.
- `(exception, bytes)`: The expected exception when deserializing the bytes and the bytes to deserialize.

The `gen_serializable_test` function generates a test class with the following tests:

gen_serializable_test(
context=globals(),
cls=ToyClass,
fields=[("a", int), ("b", str)],
test_data=[
((1, "hello"), b"\x01\x05hello"),
((2, "world"), b"\x02\x05world"),
((0, "hello"), ZeroDivisionError),
((1, "hello world"), ValueError),
(ZeroDivisionError, b"\x00"),
(IOError, b"\x01"),
],
)

The generated test class will have the following tests:

class TestGenToyClass:
def test_serialization(self):
# 2 subtests for the cases 1 and 2

def test_deserialization(self):
# 2 subtests for the cases 1 and 2

def test_validation(self):
# 2 subtests for the cases 3 and 4

def test_exceptions(self):
# 2 subtests for the cases 5 and 6
64 changes: 32 additions & 32 deletions mcproto/packets/handshaking/handshake.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from enum import IntEnum
from typing import ClassVar, final
from typing import ClassVar, cast, final

from typing_extensions import Self, override

from mcproto.buffer import Buffer
from mcproto.packets.packet import GameState, ServerBoundPacket
from mcproto.protocol.base_io import StructFormat
from mcproto.utils.abc import dataclass

__all__ = [
"NextState",
Expand All @@ -23,49 +24,38 @@ class NextState(IntEnum):


@final
@dataclass
class Handshake(ServerBoundPacket):
"""Initializes connection between server and client. (Client -> Server)."""
"""Initializes connection between server and client. (Client -> Server).
Initialize the Handshake packet.
:param protocol_version: Protocol version number to be used.
:param server_address: The host/address the client is connecting to.
:param server_port: The port the client is connecting to.
:param next_state: The next state for the server to move into.
"""

PACKET_ID: ClassVar[int] = 0x00
GAME_STATE: ClassVar[GameState] = GameState.HANDSHAKING

__slots__ = ("protocol_version", "server_address", "server_port", "next_state")

def __init__(
self,
*,
protocol_version: int,
server_address: str,
server_port: int,
next_state: NextState | int,
):
"""Initialize the Handshake packet.
:param protocol_version: Protocol version number to be used.
:param server_address: The host/address the client is connecting to.
:param server_port: The port the client is connecting to.
:param next_state: The next state for the server to move into.
"""
if not isinstance(next_state, NextState): # next_state is int
rev_lookup = {x.value: x for x in NextState.__members__.values()}
try:
next_state = rev_lookup[next_state]
except KeyError as exc:
raise ValueError("No such next_state.") from exc
# Slots are already managed by the dataclass decorator automatically.
# __slots__ = ("protocol_version", "server_address", "server_port", "next_state")

self.protocol_version = protocol_version
self.server_address = server_address
self.server_port = server_port
self.next_state = next_state
# _ : dataclasses.KW_ONLY # Only available in Python 3.10+
protocol_version: int
server_address: str
server_port: int
next_state: NextState | int

@override
def serialize(self) -> Buffer:
buf = Buffer()
def serialize_to(self, buf: Buffer) -> None:
"""Serialize the packet."""
self.next_state = cast(NextState, self.next_state) # Handled by the validate method
buf.write_varint(self.protocol_version)
buf.write_utf(self.server_address)
buf.write_value(StructFormat.USHORT, self.server_port)
buf.write_varint(self.next_state.value)
return buf

@override
@classmethod
Expand All @@ -76,3 +66,13 @@ def _deserialize(cls, buf: Buffer, /) -> Self:
server_port=buf.read_value(StructFormat.USHORT),
next_state=buf.read_varint(),
)

@override
def validate(self) -> None:
"""Validate the packet."""
if not isinstance(self.next_state, NextState):
rev_lookup = {x.value: x for x in NextState.__members__.values()}
try:
self.next_state = rev_lookup[self.next_state]
except KeyError as exc:
raise ValueError("No such next_state.") from exc
Loading

0 comments on commit 09ff7b7

Please sign in to comment.