diff --git a/doc/reference/serialization.rst b/doc/reference/serialization.rst index 6793d9548..dbf18e4cc 100644 --- a/doc/reference/serialization.rst +++ b/doc/reference/serialization.rst @@ -114,7 +114,10 @@ A ``Serializer`` can be extended with additional data types by calling ``seriali "varlenI", "4 + ?", "str (length < 4294967295)" "doublevarlenH", "2 + ?", "str (length ? < 65356)" "payload", "2 + ?", "Serializable" - + "payload-list", "?", "[Serializable]" + "arrayH-?", "2 + ? * 1", "[bool]" + "arrayH-q", "2 + ? * 8", "[int]" + "arrayH-d", "2 + ? * 8", "[float]" Some of these data types represent common usage of serializable classes: diff --git a/ipv8/messaging/payload_dataclass.py b/ipv8/messaging/payload_dataclass.py index 00e102e0f..5d13629ad 100644 --- a/ipv8/messaging/payload_dataclass.py +++ b/ipv8/messaging/payload_dataclass.py @@ -17,11 +17,13 @@ def type_from_format(fmt: str) -> TypeVar: return out -def type_map(t: Type) -> FormatListType: +def type_map(t: Type) -> FormatListType: # noqa: PLR0911 if t is bool: return "?" if t is int: return "q" + if t is float: + return "d" if t is bytes: return "varlenH" if t is str: @@ -29,11 +31,15 @@ def type_map(t: Type) -> FormatListType: if isinstance(t, TypeVar): return t.__name__ if getattr(t, '__origin__', None) in (tuple, list, set): - return [t.__args__[0]] + fmt = t.__args__[0] + if issubclass(fmt, Serializable): + return [fmt] + return f"arrayH-{type_map(t.__args__[0])}" if isinstance(t, (tuple, list, set)) or Serializable in getattr(t, "mro", list)(): return cast(Type[Serializable], t) raise NotImplementedError(t, " unknown") + def dataclass(cls: type | None = None, *, # noqa: PLR0913 init: bool = True, repr: bool = True, # noqa: A002 diff --git a/ipv8/messaging/serialization.py b/ipv8/messaging/serialization.py index b2a9a90e7..7d0e15756 100644 --- a/ipv8/messaging/serialization.py +++ b/ipv8/messaging/serialization.py @@ -3,6 +3,7 @@ import abc import socket import typing +from array import array from binascii import hexlify from contextlib import suppress from struct import Struct, pack, unpack_from @@ -331,6 +332,40 @@ def unpack(self, data: bytes, offset: int, unpack_list: list, *args: object) -> return offset +class DefaultArray(Packer): + """ + A format known to the ``array`` module (like 'I', 'B', etc.). + + Also adds support for '?'. + """ + + def __init__(self, format_str: str, length_format: str) -> None: + """ + Create a new packer for the given ``array`` format string. + """ + self.format_str = format_str + self.real_format_str = "B" if format_str == "?" else format_str + self.length_format = length_format + self.length_size = Struct(length_format).size + self.base = array(self.real_format_str).itemsize + + def pack(self, data: list) -> bytes: + """ + Pack a list of items by forwarding them to ``array``. + """ + return pack(self.length_format, len(data)) + array(self.real_format_str, data).tobytes() + + def unpack(self, data: bytes, offset: int, unpack_list: list, *args: object) -> int: + """ + Unpack a list of items from the known ``array`` format. + """ + str_length = unpack_from(self.length_format, data, offset)[0] * self.base + a = array(self.real_format_str) + a.frombytes(data[offset + self.length_size: offset + self.length_size + str_length]) + unpack_list.append([bool(b) for b in a] if self.format_str == "?" else list(a)) + return offset + self.length_size + str_length + + class DefaultStruct(Packer): """ A format known to the ``struct`` module (like 'I', '20s', etc.). @@ -351,7 +386,7 @@ def pack(self, *data: list) -> bytes: def unpack(self, data: bytes, offset: int, unpack_list: list, *args: object) -> int: """ - Unpack a list of items from a the known ``struct`` format. + Unpack a list of items from the known ``struct`` format. """ result = unpack_from(self.format_str, data, offset) unpack_list.append(result if len(result) > 1 else result[0]) @@ -407,7 +442,10 @@ def __init__(self) -> None: 'varlenI': VarLen('>I'), 'doublevarlenH': VarLen('>H'), 'payload': NestedPayload(self), - 'payload-list': ListOf(NestedPayload(self)) + 'payload-list': ListOf(NestedPayload(self)), + 'arrayH-?': DefaultArray("?", "H"), + 'arrayH-q': DefaultArray("q", "H"), + 'arrayH-d': DefaultArray("d", "H"), } def get_available_formats(self) -> list[str]: diff --git a/ipv8/test/messaging/test_payload_dataclass.py b/ipv8/test/messaging/test_payload_dataclass.py index 182f5e44f..dcd711a00 100644 --- a/ipv8/test/messaging/test_payload_dataclass.py +++ b/ipv8/test/messaging/test_payload_dataclass.py @@ -75,6 +75,21 @@ class NestedListType: a: List[NativeInt] # Backward compatibility: Python >= 3.9 can use ``list[NativeInt]`` +@dataclass +class ListIntType: + """ + A single list of integers. + """ + + a: List[int] + +@dataclass +class ListBoolType: + """ + A single list of booleans. + """ + + a: List[bool] @ogdataclass class Unknown: @@ -161,6 +176,8 @@ class Everything: d: EverythingItem e: List[EverythingItem] # Backward compatibility: Python >= 3.9 can use ``list[EverythingItem]`` f: str + g: List[int] + h: List[bool] class TestDataclassPayload(TestBase): @@ -348,6 +365,26 @@ def test_nested_payload(self) -> None: self.assertEqual(payload.a, NativeInt(42)) self.assertEqual(deserialized.a, NativeInt(42)) + def test_native_intlist_payload(self) -> None: + """ + Check if a list of native types works correctly. + """ + payload = ListIntType([1, 2]) + deserialized = self._pack_and_unpack(ListIntType, payload) + + self.assertListEqual(payload.a, [1, 2]) + self.assertListEqual(deserialized.a, [1, 2]) + + def test_native_boollist_payload(self) -> None: + """ + Check if a list of native types works correctly. + """ + payload = ListBoolType([True, False]) + deserialized = self._pack_and_unpack(ListBoolType, payload) + + self.assertListEqual(payload.a, [True, False]) + self.assertListEqual(deserialized.a, [True, False]) + def test_nestedlist_empty_payload(self) -> None: """ Check if an empty list of nested payloads works correctly. @@ -416,7 +453,9 @@ def test_everything(self) -> None: b'1337', EverythingItem(True), [EverythingItem(False), EverythingItem(True)], - "hi") + "hi", + [3, 4], + [False, True]) self.assertTrue(is_dataclass(a)) @@ -439,3 +478,9 @@ def test_everything(self) -> None: self.assertEqual(a.f, "hi") self.assertEqual(r.f, "hi") + + self.assertEqual(a.g, [3, 4]) + self.assertEqual(r.g, [3, 4]) + + self.assertEqual(a.h, [False, True]) + self.assertEqual(r.h, [False, True])