Skip to content

Commit

Permalink
Use attrs.define instead of dataclasses
Browse files Browse the repository at this point in the history
- Split the changelog
- Use TypeGuard correctly
- Remove `transform` in favor of `__attrs_post_init__` to allow for a more personalized use of `validate`
- Remove `define` from the docs, change format in changelog
  • Loading branch information
LiteApplication committed Jun 7, 2024
1 parent 3d88332 commit f5b2767
Show file tree
Hide file tree
Showing 16 changed files with 209 additions and 212 deletions.
57 changes: 57 additions & 0 deletions changes/285.internal.1.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
- Changed the way `Serializable` classes are handled:

Here is how a basic `Serializable` class looks like:

```python
@final
@dataclass
class ToyClass(Serializable):
"""Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`."""

a: int
b: str | int

@override
def __attrs_post_init__(self):
"""Initialize the object."""
if isinstance(self.b, int):
self.b = str(self.b)

super().__attrs_post_init__() # This will call validate()

@override
def serialize_to(self, buf: Buffer):
"""Write the object to a buffer."""
self.b = cast(str, self.b) # Handled by the __attrs_post_init__ method
buf.write_varint(self.a)
buf.write_utf(self.b)

@classmethod
@override
def deserialize(cls, buf: Buffer) -> ToyClass:
"""Deserialize the object from a buffer."""
a = buf.read_varint()
if a == 0:
raise ZeroDivisionError("a must be non-zero")
b = buf.read_utf()
return cls(a, b)

@override
def validate(self) -> None:
"""Validate the object's attributes."""
if self.a == 0:
raise ZeroDivisionError("a must be non-zero")
if len(self.b) > 10:
raise ValueError("b must be less than 10 characters")

```

The `Serializable` class implement the following methods:

- `serialize_to(buf: Buffer) -> None`: Serializes the object to a buffer.
- `deserialize(buf: Buffer) -> Serializable`: Deserializes the object from a buffer.

And the following optional methods:

- `validate() -> None`: Validates the object's attributes, raising an exception if they are invalid.
- `__attrs_post_init__() -> None`: Initializes the object. Call `super().__attrs_post_init__()` to validate the object.
61 changes: 4 additions & 57 deletions changes/285.internal.md → changes/285.internal.2.md
Original file line number Diff line number Diff line change
@@ -1,57 +1,3 @@
- Changed the way `Serializable` classes are handled:

Here is how a basic `Serializable` class looks like:
```python
@final
@dataclass
class ToyClass(Serializable):
"""Toy class for testing demonstrating the use of gen_serializable_test on `Serializable`."""

a: int
b: str | int

@override
def serialize_to(self, buf: Buffer):
"""Write the object to a buffer."""
self.b = cast(str, self.b) # Handled by the transform method
buf.write_varint(self.a)
buf.write_utf(self.b)

@classmethod
@override
def deserialize(cls, buf: Buffer) -> ToyClass:
"""Deserialize the object from a buffer."""
a = buf.read_varint()
if a == 0:
raise ZeroDivisionError("a must be non-zero")
b = buf.read_utf()
return cls(a, b)

@override
def validate(self) -> None:
"""Validate the object's attributes."""
if self.a == 0:
raise ZeroDivisionError("a must be non-zero")
if (isinstance(self.b, int) and math.log10(self.b) > 10) or (isinstance(self.b, str) and len(self.b) > 10):
raise ValueError("b must be less than 10 characters")

@override
def transform(self) -> None:
"""Apply a transformation to the payload of the object."""
if isinstance(self.b, int):
self.b = str(self.b)
```


The `Serializable` class implement the following methods:
- `serialize_to(buf: Buffer) -> None`: Serializes the object to a buffer.
- `deserialize(buf: Buffer) -> Serializable`: Deserializes the object from a buffer.

And the following optional methods:
- `validate() -> None`: Validates the object's attributes, raising an exception if they are invalid.
- `transform() -> None`: Transforms the the object's attributes, this method is meant to convert types like you would in a classic `__init__`.
You can rely on this `validate` having been executed.

- Added a test generator for `Serializable` classes:

The `gen_serializable_test` function generates tests for `Serializable` classes. It takes the following arguments:
Expand All @@ -69,6 +15,7 @@ And the following optional methods:
- `(exception, bytes)`: The expected exception when deserializing the bytes and the bytes to deserialize.

The `gen_serializable_test` function generates a test class with the following tests:

```python
gen_serializable_test(
context=globals(),
Expand All @@ -80,7 +27,7 @@ gen_serializable_test(
((3, 1234567890), b"\x03\x0a1234567890"),
((0, "hello"), ZeroDivisionError("a must be non-zero")), # With an error message
((1, "hello world"), ValueError), # No error message
((1, 12345678900), ValueError),
((1, 12345678900), ValueError("b must be less than 10 .*")), # With an error message and regex
(ZeroDivisionError, b"\x00"),
(ZeroDivisionError, b"\x01\x05hello"),
(IOError, b"\x01"),
Expand All @@ -90,7 +37,7 @@ gen_serializable_test(

The generated test class will have the following tests:

```python
```python
class TestGenToyClass:
def test_serialization(self):
# 2 subtests for the cases 1 and 2
Expand All @@ -103,4 +50,4 @@ class TestGenToyClass:

def test_exceptions(self):
# 2 subtests for the cases 5 and 6
```
```
1 change: 1 addition & 0 deletions docs/api/internal.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ should not be used externally**, as we do not guarantee their backwards compatib
may be introduced between patch versions without any warnings.

.. automodule:: mcproto.utils.abc
:exclude-members: define

.. autofunction:: tests.helpers.gen_serializable_test
..
Expand Down
18 changes: 10 additions & 8 deletions mcproto/packets/handshaking/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mcproto.buffer import Buffer
from mcproto.packets.packet import GameState, ServerBoundPacket
from mcproto.protocol.base_io import StructFormat
from mcproto.utils.abc import dataclass
from mcproto.utils.abc import define

__all__ = [
"NextState",
Expand All @@ -24,7 +24,7 @@ class NextState(IntEnum):


@final
@dataclass
@define
class Handshake(ServerBoundPacket):
"""Initializes connection between server and client. (Client -> Server).
Expand All @@ -44,10 +44,17 @@ class Handshake(ServerBoundPacket):
server_port: int
next_state: NextState | int

@override
def __attrs_post_init__(self) -> None:
if not isinstance(self.next_state, NextState):
self.next_state = NextState(self.next_state)

super().__attrs_post_init__()

@override
def serialize_to(self, buf: Buffer) -> None:
"""Serialize the packet."""
self.next_state = cast(NextState, self.next_state) # Handled by the transform method
self.next_state = cast(NextState, self.next_state) # Handled by the __attrs_post_init__ method
buf.write_varint(self.protocol_version)
buf.write_utf(self.server_address)
buf.write_value(StructFormat.USHORT, self.server_port)
Expand All @@ -69,8 +76,3 @@ def validate(self) -> None:
rev_lookup = {x.value: x for x in NextState.__members__.values()}
if self.next_state not in rev_lookup:
raise ValueError("No such next_state.")

@override
def transform(self) -> None:
"""Get the next state enum from the integer value."""
self.next_state = NextState(self.next_state)
30 changes: 16 additions & 14 deletions mcproto/packets/login/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket
from mcproto.types.chat import ChatMessage
from mcproto.types.uuid import UUID
from mcproto.utils.abc import dataclass
from mcproto.utils.abc import define

__all__ = [
"LoginDisconnect",
Expand All @@ -26,7 +26,7 @@


@final
@dataclass
@define
class LoginStart(ServerBoundPacket):
"""Packet from client asking to start login process. (Client -> Server).
Expand Down Expand Up @@ -56,7 +56,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self:


@final
@dataclass
@define
class LoginEncryptionRequest(ClientBoundPacket):
"""Used by the server to ask the client to encrypt the login process. (Server -> Client).
Expand All @@ -74,6 +74,13 @@ class LoginEncryptionRequest(ClientBoundPacket):
verify_token: bytes
server_id: str | None = None

@override
def __attrs_post_init__(self) -> None:
if self.server_id is None:
self.server_id = " " * 20

super().__attrs_post_init__()

@override
def serialize_to(self, buf: Buffer) -> None:
self.server_id = cast(str, self.server_id)
Expand All @@ -96,14 +103,9 @@ def _deserialize(cls, buf: Buffer, /) -> Self:

return cls(server_id=server_id, public_key=public_key, verify_token=verify_token)

@override
def transform(self) -> None:
if self.server_id is None:
self.server_id = " " * 20


@final
@dataclass
@define
class LoginEncryptionResponse(ServerBoundPacket):
"""Response from the client to :class:`LoginEncryptionRequest` packet. (Client -> Server).
Expand Down Expand Up @@ -134,7 +136,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self:


@final
@dataclass
@define
class LoginSuccess(ClientBoundPacket):
"""Sent by the server to denote a successful login. (Server -> Client).
Expand Down Expand Up @@ -164,7 +166,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self:


@final
@dataclass
@define
class LoginDisconnect(ClientBoundPacket):
"""Sent by the server to kick a player while in the login state. (Server -> Client).
Expand All @@ -190,7 +192,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self:


@final
@dataclass
@define
class LoginPluginRequest(ClientBoundPacket):
"""Sent by the server to implement a custom handshaking flow. (Server -> Client).
Expand Down Expand Up @@ -224,7 +226,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self:


@final
@dataclass
@define
class LoginPluginResponse(ServerBoundPacket):
"""Response to LoginPluginRequest from client. (Client -> Server).
Expand Down Expand Up @@ -254,7 +256,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self:


@final
@dataclass
@define
class LoginSetCompression(ClientBoundPacket):
"""Sent by the server to specify whether to use compression on future packets or not (Server -> Client).
Expand Down
4 changes: 2 additions & 2 deletions mcproto/packets/status/ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from mcproto.buffer import Buffer
from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket
from mcproto.protocol.base_io import StructFormat
from mcproto.utils.abc import dataclass
from mcproto.utils.abc import define

__all__ = ["PingPong"]


@final
@dataclass
@define
class PingPong(ClientBoundPacket, ServerBoundPacket):
"""Ping request/Pong response (Server <-> Client).
Expand Down
6 changes: 3 additions & 3 deletions mcproto/packets/status/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

from mcproto.buffer import Buffer
from mcproto.packets.packet import ClientBoundPacket, GameState, ServerBoundPacket
from mcproto.utils.abc import dataclass
from mcproto.utils.abc import define

__all__ = ["StatusRequest", "StatusResponse"]


@final
@dataclass
@define
class StatusRequest(ServerBoundPacket):
"""Request from the client to get information on the server. (Client -> Server)."""

Expand All @@ -31,7 +31,7 @@ def _deserialize(cls, buf: Buffer, /) -> Self: # pragma: no cover, nothing to t


@final
@dataclass
@define
class StatusResponse(ClientBoundPacket):
"""Response from the server to requesting client with status data information. (Server -> Client).
Expand Down
4 changes: 2 additions & 2 deletions mcproto/types/abc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from mcproto.utils.abc import Serializable, dataclass
from mcproto.utils.abc import Serializable, define

__all__ = ["MCType", "dataclass"] # That way we can import it from mcproto.types.abc
__all__ = ["MCType", "define"] # That way we can import it from mcproto.types.abc


class MCType(Serializable):
Expand Down
6 changes: 4 additions & 2 deletions mcproto/types/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Self, TypeAlias, override

from mcproto.buffer import Buffer
from mcproto.types.abc import MCType, dataclass
from mcproto.types.abc import MCType, define

__all__ = [
"ChatMessage",
Expand All @@ -33,13 +33,15 @@ class RawChatMessageDict(TypedDict, total=False):
RawChatMessage: TypeAlias = Union[RawChatMessageDict, "list[RawChatMessageDict]", str]


@dataclass
@final
@define
class ChatMessage(MCType):
"""Minecraft chat message representation."""

raw: RawChatMessage

__slots__ = ("raw",)

def as_dict(self) -> RawChatMessageDict:
"""Convert received ``raw`` into a stadard :class:`dict` form."""
if isinstance(self.raw, list):
Expand Down
Loading

0 comments on commit f5b2767

Please sign in to comment.