diff --git a/changes/2754.fix.md b/changes/2754.fix.md new file mode 100644 index 0000000000..e7d6d77916 --- /dev/null +++ b/changes/2754.fix.md @@ -0,0 +1 @@ +Correct `msgpack` deserialization of `ResourceSlot`. \ No newline at end of file diff --git a/src/ai/backend/common/msgpack.py b/src/ai/backend/common/msgpack.py index d8a2043ba3..9a435b8f4a 100644 --- a/src/ai/backend/common/msgpack.py +++ b/src/ai/backend/common/msgpack.py @@ -14,7 +14,7 @@ import msgpack as _msgpack import temporenc -from .types import BinarySize +from .types import BinarySize, ResourceSlot __all__ = ("packb", "unpackb") @@ -27,6 +27,7 @@ class ExtTypes(enum.IntEnum): POSIX_PATH = 4 PURE_POSIX_PATH = 5 ENUM = 6 + RESOURCE_SLOT = 8 BACKENDAI_BINARY_SIZE = 16 @@ -46,6 +47,8 @@ def _default(obj: object) -> Any: return _msgpack.ExtType(ExtTypes.POSIX_PATH, os.fsencode(obj)) case PurePosixPath(): return _msgpack.ExtType(ExtTypes.PURE_POSIX_PATH, os.fsencode(obj)) + case ResourceSlot(): + return _msgpack.ExtType(ExtTypes.RESOURCE_SLOT, pickle.dumps(obj, protocol=5)) case enum.Enum(): return _msgpack.ExtType(ExtTypes.ENUM, pickle.dumps(obj, protocol=5)) raise TypeError(f"Unknown type: {obj!r} ({type(obj)})") @@ -65,6 +68,8 @@ def _ext_hook(code: int, data: bytes) -> Any: return PurePosixPath(os.fsdecode(data)) case ExtTypes.ENUM: return pickle.loads(data) + case ExtTypes.RESOURCE_SLOT: + return pickle.loads(data) case ExtTypes.BACKENDAI_BINARY_SIZE: return pickle.loads(data) return _msgpack.ExtType(code, data) diff --git a/tests/common/test_msgpack.py b/tests/common/test_msgpack.py index 3b4b811296..74ae3829c4 100644 --- a/tests/common/test_msgpack.py +++ b/tests/common/test_msgpack.py @@ -6,7 +6,7 @@ from dateutil.tz import gettz, tzutc from ai.backend.common import msgpack -from ai.backend.common.types import BinarySize, SlotTypes +from ai.backend.common.types import BinarySize, ResourceSlot, SlotTypes def test_msgpack_with_unicode(): @@ -125,3 +125,20 @@ def test_msgpack_posixpath(): unpacked = msgpack.unpackb(packed) assert isinstance(unpacked["path"], PosixPath) assert unpacked["path"] == path + + +def test_msgpack_resource_slot(): + resource_slot = ResourceSlot({"cpu": 1, "mem": 1024}) + packed = msgpack.packb(resource_slot) + unpacked = msgpack.unpackb(packed) + assert unpacked == resource_slot + + resource_slot = ResourceSlot({"cpu": 2, "mem": Decimal(1024**5)}) + packed = msgpack.packb(resource_slot) + unpacked = msgpack.unpackb(packed) + assert unpacked == resource_slot + + resource_slot = ResourceSlot({"cpu": 3, "mem": "1125899906842624"}) + packed = msgpack.packb(resource_slot) + unpacked = msgpack.unpackb(packed) + assert unpacked == resource_slot