diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 29d27bfb..b1634b93 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: args: [--sp, python/setup.cfg] files: python/xoscar - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.3.0 + rev: v1.4.1 hooks: - id: mypy additional_dependencies: [tokenize-rt==3.2.0] diff --git a/python/xoscar/api.py b/python/xoscar/api.py index e239a17c..5dcee3cd 100644 --- a/python/xoscar/api.py +++ b/python/xoscar/api.py @@ -17,12 +17,13 @@ from collections import defaultdict from numbers import Number -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from urllib.parse import urlparse +from .aio import AioFileObject from .backend import get_backend from .context import get_context -from .core import ActorRef, BufferRef, _Actor, _StatelessActor +from .core import ActorRef, BufferRef, FileObjectRef, _Actor, _StatelessActor if TYPE_CHECKING: from .backends.config import ActorPoolConfig @@ -180,25 +181,44 @@ def buffer_ref(address: str, buffer: Any) -> BufferRef: return ctx.buffer_ref(address, buffer) +def file_object_ref(address: str, fileobj: AioFileObject) -> FileObjectRef: + """ + Init file object ref according to address and aio file obj. + + Parameters + ---------- + address + The address of the file obj. + fileobj + Aio file object. + + Returns + ---------- + FileObjectRef obj. + """ + ctx = get_context() + return ctx.file_object_ref(address, fileobj) + + async def copy_to( - local_buffers: list, - remote_buffer_refs: List[BufferRef], + local_buffers_or_fileobjs: list, + remote_refs: List[Union[BufferRef, FileObjectRef]], block_size: Optional[int] = None, ): """ - Copy data from local buffers to remote buffers. + Copy data from local buffers to remote buffers or copy local file objects to remote file objects. Parameters ---------- - local_buffers - Local buffers. - remote_buffer_refs - Remote buffer refs. + local_buffers_or_fileobjs + Local buffers or file objects. + remote_refs + Remote buffer refs or file object refs. block_size Transfer block size when non-ucx """ ctx = get_context() - return await ctx.copy_to(local_buffers, remote_buffer_refs, block_size) + return await ctx.copy_to(local_buffers_or_fileobjs, remote_refs, block_size) async def wait_actor_pool_recovered(address: str, main_pool_address: str | None = None): diff --git a/python/xoscar/backends/context.py b/python/xoscar/backends/context.py index 01a5d859..5f981b64 100644 --- a/python/xoscar/backends/context.py +++ b/python/xoscar/backends/context.py @@ -20,9 +20,10 @@ from typing import Any, List, Optional, Tuple, Type, Union from .._utils import create_actor_ref, to_binary +from ..aio import AioFileObject from ..api import Actor from ..context import BaseActorContext -from ..core import ActorRef, BufferRef, create_local_actor_ref +from ..core import ActorRef, BufferRef, FileObjectRef, create_local_actor_ref from ..debug import debug_async_timeout, detect_cycle_send from ..errors import CannotCancelTask from ..utils import dataslots @@ -36,6 +37,7 @@ ControlMessage, ControlMessageType, CopyToBuffersMessage, + CopyToFileObjectsMessage, CreateActorMessage, DestroyActorMessage, ErrorMessage, @@ -281,9 +283,13 @@ def _gen_switch_to_copy_to_control_message(content: Any): ) @staticmethod - def _gen_copy_to_message(content: Any): + def _gen_copy_to_buffers_message(content: Any): return CopyToBuffersMessage(message_id=new_message_id(), content=content) # type: ignore + @staticmethod + def _gen_copy_to_fileobjs_message(content: Any): + return CopyToFileObjectsMessage(message_id=new_message_id(), content=content) # type: ignore + async def _get_copy_to_client(self, router, address) -> Client: client = await self._caller.get_client(router, address) if isinstance(client, DummyClient) or hasattr(client, "send_buffers"): @@ -302,28 +308,19 @@ async def _get_copy_to_client(self, router, address) -> Client: else: return await self._caller.get_client_via_type(router, address, client_type) - async def copy_to( + async def _get_client(self, address: str) -> Client: + router = Router.get_instance() + assert router is not None, "`copy_to` can only be used inside pools" + return await self._get_copy_to_client(router, address) + + async def copy_to_buffers( self, local_buffers: list, remote_buffer_refs: List[BufferRef], block_size: Optional[int] = None, ): - assert ( - len({ref.address for ref in remote_buffer_refs}) == 1 - ), "remote buffers for `copy_via_buffers` can support only 1 destination" - assert len(local_buffers) == len(remote_buffer_refs), ( - f"Buffers from local and remote must have same size, " - f"local: {len(local_buffers)}, remote: {len(remote_buffer_refs)}" - ) - if block_size is not None: - assert ( - block_size > 0 - ), f"`block_size` option must be greater than 0, current value: {block_size}." - - router = Router.get_instance() - assert router is not None, "`copy_to` can only be used inside pools" address = remote_buffer_refs[0].address - client = await self._get_copy_to_client(router, address) + client = await self._get_client(address) if isinstance(client, UCXClient): message = [(buf.address, buf.uid) for buf in remote_buffer_refs] await self._call_send_buffers( @@ -352,7 +349,7 @@ async def copy_to( (r_buf.address, r_buf.uid, last_start, remain, l_buf[:remain]) ) await self._call_with_client( - client, self._gen_copy_to_message(one_block_data) + client, self._gen_copy_to_buffers_message(one_block_data) ) one_block_data = [] current_buf_size = 0 @@ -367,5 +364,37 @@ async def copy_to( if one_block_data: await self._call_with_client( - client, self._gen_copy_to_message(one_block_data) + client, self._gen_copy_to_buffers_message(one_block_data) ) + + async def copy_to_fileobjs( + self, + local_fileobjs: List[AioFileObject], + remote_fileobj_refs: List[FileObjectRef], + block_size: Optional[int] = None, + ): + address = remote_fileobj_refs[0].address + client = await self._get_client(address) + block_size = block_size or DEFAULT_TRANSFER_BLOCK_SIZE + one_block_data = [] + current_file_size = 0 + for file_obj, remote_ref in zip(local_fileobjs, remote_fileobj_refs): + while True: + file_data = await file_obj.read(block_size) # type: ignore + if file_data: + one_block_data.append( + (remote_ref.address, remote_ref.uid, file_data) + ) + current_file_size += len(file_data) + if current_file_size >= block_size: + message = self._gen_copy_to_fileobjs_message(one_block_data) + await self._call_with_client(client, message) + one_block_data.clear() + current_file_size = 0 + else: + break + + if current_file_size > 0: + message = self._gen_copy_to_fileobjs_message(one_block_data) + await self._call_with_client(client, message) + one_block_data.clear() diff --git a/python/xoscar/backends/message.pyi b/python/xoscar/backends/message.pyi index f81f6433..0d94c879 100644 --- a/python/xoscar/backends/message.pyi +++ b/python/xoscar/backends/message.pyi @@ -35,6 +35,7 @@ class MessageType(Enum): tell = 8 cancel = 9 copy_to_buffers = 10 + copy_to_fileobjs = 11 class ControlMessageType(Enum): stop = 0 @@ -75,6 +76,9 @@ class CopyToBuffersMessage(_MessageBase): message_trace: list | None = None, ): ... +class CopyToFileObjectsMessage(CopyToBuffersMessage): + message_type = MessageType.copy_to_fileobjs + class ControlMessage(_MessageBase): message_type = MessageType.control diff --git a/python/xoscar/backends/message.pyx b/python/xoscar/backends/message.pyx index 5de3b0d3..71db25e4 100644 --- a/python/xoscar/backends/message.pyx +++ b/python/xoscar/backends/message.pyx @@ -43,6 +43,7 @@ class MessageType(Enum): tell = 8 cancel = 9 copy_to_buffers = 10 + copy_to_fileobjs = 11 class ControlMessageType(Enum): @@ -520,6 +521,10 @@ cdef class CopyToBuffersMessage(_MessageBase): self.content = subs[0] +cdef class CopyToFileObjectsMessage(CopyToBuffersMessage): + message_type = MessageType.copy_to_fileobjs + + cdef dict _message_type_to_message_cls = { MessageType.control.value: ControlMessage, MessageType.result.value: ResultMessage, @@ -532,6 +537,7 @@ cdef dict _message_type_to_message_cls = { MessageType.tell.value: TellMessage, MessageType.cancel.value: CancelMessage, MessageType.copy_to_buffers.value: CopyToBuffersMessage, + MessageType.copy_to_fileobjs.value: CopyToFileObjectsMessage } diff --git a/python/xoscar/backends/pool.py b/python/xoscar/backends/pool.py index 32b62403..b4384e9a 100644 --- a/python/xoscar/backends/pool.py +++ b/python/xoscar/backends/pool.py @@ -29,7 +29,7 @@ from .._utils import TypeDispatcher, create_actor_ref, to_binary from ..api import Actor -from ..core import ActorRef, BufferRef, register_local_pool +from ..core import ActorRef, BufferRef, FileObjectRef, register_local_pool from ..debug import debug_async_timeout, record_message_trace from ..entrypoints import init_extension_entrypoints from ..errors import ( @@ -123,7 +123,8 @@ def _register_message_handler(pool_type: Type["AbstractActorPool"]): (MessageType.tell, pool_type.tell), (MessageType.cancel, pool_type.cancel), (MessageType.control, pool_type.handle_control_command), - (MessageType.copy_to_buffers, pool_type.handle_copy_to_message), + (MessageType.copy_to_buffers, pool_type.handle_copy_to_buffers_message), + (MessageType.copy_to_fileobjs, pool_type.handle_copy_to_fileobjs_message), ]: pool_type._message_handler[message_type] = handler # type: ignore return pool_type @@ -500,12 +501,18 @@ async def stop(self): finally: self._stopped.set() - async def handle_copy_to_message(self, message) -> ResultMessage: + async def handle_copy_to_buffers_message(self, message) -> ResultMessage: for addr, uid, start, _len, data in message.content: buffer = BufferRef.get_buffer(BufferRef(addr, uid)) buffer[start : start + _len] = data return ResultMessage(message_id=message.message_id, result=True) + async def handle_copy_to_fileobjs_message(self, message) -> ResultMessage: + for addr, uid, data in message.content: + file_obj = FileObjectRef.get_local_file_object(FileObjectRef(addr, uid)) + await file_obj.write(data) + return ResultMessage(message_id=message.message_id, result=True) + @property def stopped(self) -> bool: return self._stopped.is_set() diff --git a/python/xoscar/backends/test/tests/test_transfer.py b/python/xoscar/backends/test/tests/test_transfer.py index 7e641e5a..1656f2fc 100644 --- a/python/xoscar/backends/test/tests/test_transfer.py +++ b/python/xoscar/backends/test/tests/test_transfer.py @@ -14,18 +14,21 @@ # limitations under the License. import asyncio import os +import shutil import sys +import tempfile from typing import List, Optional import numpy as np import pytest from .... import Actor, ActorRefType -from ....api import actor_ref, buffer_ref, copy_to +from ....aio import AioFileObject +from ....api import actor_ref, buffer_ref, copy_to, file_object_ref from ....backends.allocate_strategy import ProcessIndex from ....backends.indigen.pool import MainActorPool from ....context import get_context -from ....core import BufferRef +from ....core import BufferRef, FileObjectRef from ....tests.core import require_cupy from ....utils import lazy_import from ...pool import create_actor_pool @@ -203,3 +206,92 @@ async def tests_gpu_copy(scheme): if "ucx" == scheme: await _copy_test(None, "ucx", False) await _copy_test("ucx", None, False) + + +class FileobjTransferActor(Actor): + def __init__(self): + self._fileobjs = [] + + async def create_file_objects(self, names: List[str]) -> List[FileObjectRef]: + refs = [] + for name in names: + fobj = open(name, "w+b") + afobj = AioFileObject(fobj) + self._fileobjs.append(afobj) + refs.append(file_object_ref(self.address, afobj)) + return refs + + async def close(self): + for fobj in self._fileobjs: + assert await fobj.tell() > 0 + await fobj.close() + + async def copy_data( + self, + ref: ActorRefType["FileobjTransferActor"], + names1: List[str], + names2: List[str], + sizes: List[int], + ): + fobjs = [] + for name, size in zip(names1, sizes): + fobj = open(name, "w+b") + fobj.write(np.random.bytes(size)) + fobj.seek(0) + fobjs.append(AioFileObject(fobj)) + + ref = await actor_ref(ref) + file_obj_refs = await ref.create_file_objects(names2) + await copy_to(fobjs, file_obj_refs) + _ = [await f.close() for f in fobjs] # type: ignore + await ref.close() + + for n1, n2 in zip(names1, names2): + with open(n1, "rb") as f1, open(n2, "rb") as f2: + b1 = f1.read() + b2 = f2.read() + assert b1 == b2 + + +@pytest.mark.asyncio +async def test_copy_to_file_objects(): + start_method = ( + os.environ.get("POOL_START_METHOD", "forkserver") + if sys.platform != "win32" + else None + ) + pool = await create_actor_pool( + "127.0.0.1", + pool_cls=MainActorPool, + n_process=2, + subprocess_start_method=start_method, + ) + + d = tempfile.mkdtemp() + async with pool: + ctx = get_context() + + # actor on main pool + actor_ref1 = await ctx.create_actor( + FileobjTransferActor, + uid="test-1", + address=pool.external_address, + allocate_strategy=ProcessIndex(1), + ) + actor_ref2 = await ctx.create_actor( + FileobjTransferActor, + uid="test-2", + address=pool.external_address, + allocate_strategy=ProcessIndex(2), + ) + sizes = [10 * 1024**2, 3 * 1024**2, 0.5 * 1024**2, 0.25 * 1024**2] + names = [] + for _ in range(2 * len(sizes)): + _, p = tempfile.mkstemp(dir=d) + names.append(p) + + await actor_ref1.copy_data(actor_ref2, names[::2], names[1::2], sizes=sizes) + try: + shutil.rmtree(d) + except PermissionError: + pass diff --git a/python/xoscar/context.pyx b/python/xoscar/context.pyx index 673137d1..5d8d0e10 100644 --- a/python/xoscar/context.pyx +++ b/python/xoscar/context.pyx @@ -12,11 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional +from typing import Any, List, Optional, Union from urllib.parse import urlparse from ._utils cimport new_actor_id, new_random_id -from .core cimport ActorRef, BufferRef +from .core cimport ActorRef, BufferRef, FileObjectRef cdef dict _backend_context_cls = dict() @@ -196,7 +196,24 @@ cdef class BaseActorContext: """ return BufferRef.create(buf, address, new_random_id(32)) - async def copy_to(self, local_buffers: List[bytes], remote_buffer_refs: List[BufferRef], block_size: Optional[int] = None): + def file_object_ref(self, str address, object file_object) -> FileObjectRef: + """ + Create a reference to an aio file object + + Parameters + ---------- + address + address of the actor pool + file_object + aio file object + + Returns + ------- + FileObjectRef + """ + return FileObjectRef.create(file_object, address, new_random_id(32)) + + async def copy_to_buffers(self, local_buffers: List, remote_buffer_refs: List[BufferRef], block_size: Optional[int] = None): """ Copy local buffers to remote buffers. Parameters @@ -210,6 +227,20 @@ cdef class BaseActorContext: """ raise NotImplementedError + async def copy_to_fileobjs(self, local_fileobjs: list, remote_fileobj_refs: List[FileObjectRef], block_size: Optional[int] = None): + """ + Copy local file objs to remote file objs. + Parameters + ---------- + local_fileobjs + Local file objs. + remote_fileobj_refs + Remote file object refs + block_size + Transfer block size when non-ucx + """ + raise NotImplementedError + cdef class ClientActorContext(BaseActorContext): """ @@ -294,12 +325,31 @@ cdef class ClientActorContext(BaseActorContext): context = self._get_backend_context(address) return context.buffer_ref(address, buf) - def copy_to(self, local_buffers: list, remote_buffer_refs: List[BufferRef], block_size: Optional[int] = None): - if len(local_buffers) == 0 or len(remote_buffer_refs) == 0: - raise ValueError("Nothing to transfer since the length of `local_buffers` or `remote_buffer_refs` is 0.") - address = remote_buffer_refs[0].address + def file_object_ref(self, str address, object file_object) -> FileObjectRef: context = self._get_backend_context(address) - return context.copy_to(local_buffers, remote_buffer_refs, block_size) + return context.file_object_ref(address, file_object) + + def copy_to(self, local_buffers_or_fileobjs: list, remote_refs: List[Union[BufferRef, FileObjectRef]], block_size: Optional[int] = None): + if len(local_buffers_or_fileobjs) == 0 or len(remote_refs) == 0: + raise ValueError("Nothing to transfer since the length of `local_buffers_or_fileobjs` or `remote_refs` is 0.") + assert ( + len({ref.address for ref in remote_refs}) == 1 + ), "remote_refs for `copy_to` can support only 1 destination" + assert len(local_buffers_or_fileobjs) == len(remote_refs), ( + f"Buffers or fileobjs from local and remote must have same size, " + f"local: {len(local_buffers_or_fileobjs)}, remote: {len(remote_refs)}" + ) + if block_size is not None: + assert ( + block_size > 0 + ), f"`block_size` option must be greater than 0, current value: {block_size}." + remote_ref = remote_refs[0] + address = remote_ref.address + context = self._get_backend_context(address) + if isinstance(remote_ref, BufferRef): + return context.copy_to_buffers(local_buffers_or_fileobjs, remote_refs, block_size) + else: + return context.copy_to_fileobjs(local_buffers_or_fileobjs, remote_refs, block_size) def register_backend_context(scheme, cls): diff --git a/python/xoscar/core.pxd b/python/xoscar/core.pxd index af40f2b3..b99c1945 100644 --- a/python/xoscar/core.pxd +++ b/python/xoscar/core.pxd @@ -31,6 +31,11 @@ cdef class BufferRef: cdef public bytes uid +cdef class FileObjectRef: + cdef public str address + cdef public bytes uid + + cdef class _BaseActor: cdef object __weakref__ cdef str _address diff --git a/python/xoscar/core.pyx b/python/xoscar/core.pyx index 44b2f409..6c06eaf6 100644 --- a/python/xoscar/core.pyx +++ b/python/xoscar/core.pyx @@ -22,6 +22,7 @@ from typing import Any, AsyncGenerator cimport cython +from .aio import AioFileObject from .context cimport get_context from .errors import ActorNotExist, Return @@ -585,4 +586,42 @@ cdef class BufferRef: return self.address == other.address and self.uid == other.uid def __repr__(self): - return f'BufferRef(uid={self.uid}, address={self.address})' + return f'BufferRef(uid={self.uid.hex()}, address={self.address})' + + +cdef class FileObjectRef: + """ + Reference of a file obj + """ + _ref_to_fileobjs = weakref.WeakValueDictionary() + + def __init__(self, str address, bytes uid): + self.uid = uid + self.address = address + + @classmethod + def create(cls, fileobj: AioFileObject, address: str, uid: bytes) -> "FileObjectRef": + ref = FileObjectRef(address, uid) + cls._ref_to_fileobjs[ref] = fileobj + return ref + + @classmethod + def get_local_file_object(cls, ref: "FileObjectRef") -> AioFileObject: + return cls._ref_to_fileobjs[ref] + + def __getstate__(self): + return self.uid, self.address + + def __setstate__(self, state): + self.uid, self.address = state + + def __hash__(self): + return hash((self.address, self.uid)) + + def __eq__(self, other): + if type(other) != FileObjectRef: + return False + return self.address == other.address and self.uid == other.uid + + def __repr__(self): + return f'FileObjectRef(uid={self.uid.hex()}, address={self.address})'