Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Serializable (no NBT) #273

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions changes/273.internal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
- Changed the way `Serializable` classes are handled:

Here is how a basic `Serializable` class looks like:

@final
@dataclass
class ToyClass(Serializable):
"""Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`."""


# Attributes can be of any type
a: int
b: str

# dataclasses.field() can be used to specify additional metadata

def serialize_to(self, buf: Buffer):
"""Write the object to a buffer."""
buf.write_varint(self.a)
buf.write_utf(self.b)

@classmethod
def deserialize(cls, buf: Buffer) -> ToyClass:
"""Deserialize the object from a buffer."""
a = buf.read_varint()
if a == 0:
raise ZeroDivisionError("a must be non-zero")
b = buf.read_utf()
return cls(a, b)

def validate(self) -> None:
"""Validate the object's attributes."""
if self.a == 0:
raise ZeroDivisionError("a must be non-zero")
if len(self.b) > 10:
raise ValueError("b must be less than 10 characters")


The `Serializable` class must implement the following methods:

- `serialize_to(buf: Buffer) -> None`: Serializes the object to a buffer.
- `deserialize(buf: Buffer) -> Serializable`: Deserializes the object from a buffer.
- `validate() -> None`: Validates the object's attributes, raising an exception if they are invalid.

- Added a test generator for `Serializable` classes:

The `gen_serializable_test` function generates tests for `Serializable` classes. It takes the following arguments:

- `context`: The dictionary containing the context in which the generated test class will be placed (e.g. `globals()`).
> Dictionary updates must reflect in the context. This is the case for `globals()` but implementation-specific for `locals()`.
- `cls`: The `Serializable` class to generate tests for.
- `fields`: A list of fields where the test values will be placed.

> In the example above, the `ToyClass` class has two fields: `a` and `b`.

- `test_data`: A list of tuples containing either:
- `((field1_value, field2_value, ...), expected_bytes)`: The values of the fields and the expected serialized bytes. This needs to work both ways, i.e. `cls(field1_value, field2_value, ...) == cls.deserialize(expected_bytes).`
- `((field1_value, field2_value, ...), exception)`: The values of the fields and the expected exception when validating the object.
- `(exception, bytes)`: The expected exception when deserializing the bytes and the bytes to deserialize.

The `gen_serializable_test` function generates a test class with the following tests:

gen_serializable_test(
context=globals(),
cls=ToyClass,
fields=[("a", int), ("b", str)],
test_data=[
((1, "hello"), b"\x01\x05hello"),
((2, "world"), b"\x02\x05world"),
((0, "hello"), ZeroDivisionError),
((1, "hello world"), ValueError),
(ZeroDivisionError, b"\x00"),
(IOError, b"\x01"),
],
)

The generated test class will have the following tests:

class TestGenToyClass:
def test_serialization(self):
# 2 subtests for the cases 1 and 2

def test_deserialization(self):
# 2 subtests for the cases 1 and 2

def test_validation(self):
# 2 subtests for the cases 3 and 4

def test_exceptions(self):
# 2 subtests for the cases 5 and 6
64 changes: 32 additions & 32 deletions mcproto/packets/handshaking/handshake.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from enum import IntEnum
from typing import ClassVar, final
from typing import ClassVar, cast, final

from typing_extensions import Self, override

from mcproto.buffer import Buffer
from mcproto.packets.packet import GameState, ServerBoundPacket
from mcproto.protocol.base_io import StructFormat
from mcproto.utils.abc import dataclass

__all__ = [
"NextState",
Expand All @@ -23,49 +24,38 @@ class NextState(IntEnum):


@final
@dataclass
class Handshake(ServerBoundPacket):
"""Initializes connection between server and client. (Client -> Server)."""
"""Initializes connection between server and client. (Client -> Server).

Initialize the Handshake packet.

:param protocol_version: Protocol version number to be used.
:param server_address: The host/address the client is connecting to.
:param server_port: The port the client is connecting to.
:param next_state: The next state for the server to move into.
"""

PACKET_ID: ClassVar[int] = 0x00
GAME_STATE: ClassVar[GameState] = GameState.HANDSHAKING

__slots__ = ("protocol_version", "server_address", "server_port", "next_state")

def __init__(
self,
*,
protocol_version: int,
server_address: str,
server_port: int,
next_state: NextState | int,
):
"""Initialize the Handshake packet.

:param protocol_version: Protocol version number to be used.
:param server_address: The host/address the client is connecting to.
:param server_port: The port the client is connecting to.
:param next_state: The next state for the server to move into.
"""
if not isinstance(next_state, NextState): # next_state is int
rev_lookup = {x.value: x for x in NextState.__members__.values()}
try:
next_state = rev_lookup[next_state]
except KeyError as exc:
raise ValueError("No such next_state.") from exc
# Slots are already managed by the dataclass decorator automatically.
# __slots__ = ("protocol_version", "server_address", "server_port", "next_state")

self.protocol_version = protocol_version
self.server_address = server_address
self.server_port = server_port
self.next_state = next_state
# _ : dataclasses.KW_ONLY # Only available in Python 3.10+
protocol_version: int
server_address: str
server_port: int
next_state: NextState | int

@override
def serialize(self) -> Buffer:
buf = Buffer()
def serialize_to(self, buf: Buffer) -> None:
"""Serialize the packet."""
self.next_state = cast(NextState, self.next_state) # Handled by the validate method
buf.write_varint(self.protocol_version)
buf.write_utf(self.server_address)
buf.write_value(StructFormat.USHORT, self.server_port)
buf.write_varint(self.next_state.value)
return buf

@override
@classmethod
Expand All @@ -76,3 +66,13 @@ def _deserialize(cls, buf: Buffer, /) -> Self:
server_port=buf.read_value(StructFormat.USHORT),
next_state=buf.read_varint(),
)

@override
def validate(self) -> None:
"""Validate the packet."""
if not isinstance(self.next_state, NextState):
rev_lookup = {x.value: x for x in NextState.__members__.values()}
try:
self.next_state = rev_lookup[self.next_state]
except KeyError as exc:
raise ValueError("No such next_state.") from exc
Loading
Loading