Skip to content

Commit

Permalink
ENH: copy_to supports aio files (#32)
Browse files Browse the repository at this point in the history
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
ChengjieLi28 and mergify[bot] authored Jul 5, 2023
1 parent 9537f54 commit 1ea9836
Show file tree
Hide file tree
Showing 10 changed files with 297 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:
args: [--sp, python/setup.cfg]
files: python/xoscar
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
rev: v1.4.1
hooks:
- id: mypy
additional_dependencies: [tokenize-rt==3.2.0]
Expand Down
40 changes: 30 additions & 10 deletions python/xoscar/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@

from collections import defaultdict
from numbers import Number
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse

from .aio import AioFileObject
from .backend import get_backend
from .context import get_context
from .core import ActorRef, BufferRef, _Actor, _StatelessActor
from .core import ActorRef, BufferRef, FileObjectRef, _Actor, _StatelessActor

if TYPE_CHECKING:
from .backends.config import ActorPoolConfig
Expand Down Expand Up @@ -180,25 +181,44 @@ def buffer_ref(address: str, buffer: Any) -> BufferRef:
return ctx.buffer_ref(address, buffer)


def file_object_ref(address: str, fileobj: AioFileObject) -> FileObjectRef:
    """
    Build a reference to an aio file object located at ``address``.

    The call is forwarded unchanged to the current actor context.

    Parameters
    ----------
    address
        The address of the file obj.
    fileobj
        Aio file object.

    Returns
    -------
    FileObjectRef
        The reference produced by the context for ``fileobj``.
    """
    return get_context().file_object_ref(address, fileobj)


async def copy_to(
local_buffers: list,
remote_buffer_refs: List[BufferRef],
local_buffers_or_fileobjs: list,
remote_refs: List[Union[BufferRef, FileObjectRef]],
block_size: Optional[int] = None,
):
"""
Copy data from local buffers to remote buffers.
Copy data from local buffers to remote buffers or copy local file objects to remote file objects.
Parameters
----------
local_buffers
Local buffers.
remote_buffer_refs
Remote buffer refs.
local_buffers_or_fileobjs
Local buffers or file objects.
remote_refs
Remote buffer refs or file object refs.
block_size
Transfer block size when non-ucx
"""
ctx = get_context()
return await ctx.copy_to(local_buffers, remote_buffer_refs, block_size)
return await ctx.copy_to(local_buffers_or_fileobjs, remote_refs, block_size)


async def wait_actor_pool_recovered(address: str, main_pool_address: str | None = None):
Expand Down
69 changes: 49 additions & 20 deletions python/xoscar/backends/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from typing import Any, List, Optional, Tuple, Type, Union

from .._utils import create_actor_ref, to_binary
from ..aio import AioFileObject
from ..api import Actor
from ..context import BaseActorContext
from ..core import ActorRef, BufferRef, create_local_actor_ref
from ..core import ActorRef, BufferRef, FileObjectRef, create_local_actor_ref
from ..debug import debug_async_timeout, detect_cycle_send
from ..errors import CannotCancelTask
from ..utils import dataslots
Expand All @@ -36,6 +37,7 @@
ControlMessage,
ControlMessageType,
CopyToBuffersMessage,
CopyToFileObjectsMessage,
CreateActorMessage,
DestroyActorMessage,
ErrorMessage,
Expand Down Expand Up @@ -281,9 +283,13 @@ def _gen_switch_to_copy_to_control_message(content: Any):
)

@staticmethod
def _gen_copy_to_message(content: Any):
def _gen_copy_to_buffers_message(content: Any):
return CopyToBuffersMessage(message_id=new_message_id(), content=content) # type: ignore

@staticmethod
def _gen_copy_to_fileobjs_message(content: Any):
    """Wrap ``content`` in a ``CopyToFileObjectsMessage`` carrying a fresh message id."""
    fresh_id = new_message_id()
    return CopyToFileObjectsMessage(content=content, message_id=fresh_id)  # type: ignore

async def _get_copy_to_client(self, router, address) -> Client:
client = await self._caller.get_client(router, address)
if isinstance(client, DummyClient) or hasattr(client, "send_buffers"):
Expand All @@ -302,28 +308,19 @@ async def _get_copy_to_client(self, router, address) -> Client:
else:
return await self._caller.get_client_via_type(router, address, client_type)

async def copy_to(
async def _get_client(self, address: str) -> Client:
    """
    Resolve the copy-to client for ``address``.

    Uses the router returned by ``Router.get_instance()`` and fails with an
    ``AssertionError`` when no router instance is available.
    """
    current_router = Router.get_instance()
    assert current_router is not None, "`copy_to` can only be used inside pools"
    copy_client = await self._get_copy_to_client(current_router, address)
    return copy_client

async def copy_to_buffers(
self,
local_buffers: list,
remote_buffer_refs: List[BufferRef],
block_size: Optional[int] = None,
):
assert (
len({ref.address for ref in remote_buffer_refs}) == 1
), "remote buffers for `copy_via_buffers` can support only 1 destination"
assert len(local_buffers) == len(remote_buffer_refs), (
f"Buffers from local and remote must have same size, "
f"local: {len(local_buffers)}, remote: {len(remote_buffer_refs)}"
)
if block_size is not None:
assert (
block_size > 0
), f"`block_size` option must be greater than 0, current value: {block_size}."

router = Router.get_instance()
assert router is not None, "`copy_to` can only be used inside pools"
address = remote_buffer_refs[0].address
client = await self._get_copy_to_client(router, address)
client = await self._get_client(address)
if isinstance(client, UCXClient):
message = [(buf.address, buf.uid) for buf in remote_buffer_refs]
await self._call_send_buffers(
Expand Down Expand Up @@ -352,7 +349,7 @@ async def copy_to(
(r_buf.address, r_buf.uid, last_start, remain, l_buf[:remain])
)
await self._call_with_client(
client, self._gen_copy_to_message(one_block_data)
client, self._gen_copy_to_buffers_message(one_block_data)
)
one_block_data = []
current_buf_size = 0
Expand All @@ -367,5 +364,37 @@ async def copy_to(

if one_block_data:
await self._call_with_client(
client, self._gen_copy_to_message(one_block_data)
client, self._gen_copy_to_buffers_message(one_block_data)
)

async def copy_to_fileobjs(
    self,
    local_fileobjs: List[AioFileObject],
    remote_fileobj_refs: List[FileObjectRef],
    block_size: Optional[int] = None,
):
    """
    Stream local aio file objects into remote file objects.

    Each local file is read in chunks of at most ``block_size`` bytes. Chunks
    are queued as ``(address, uid, data)`` tuples and sent as one message as
    soon as the queued bytes reach ``block_size``. A short tail chunk at the
    end of a file stays queued so it can share a message with data read from
    the next file; whatever is still queued after the last file is sent in a
    final message.

    Parameters
    ----------
    local_fileobjs
        Readable local file objects, paired positionally with
        ``remote_fileobj_refs``.
    remote_fileobj_refs
        Refs of the destination file objects. The client is resolved from
        the address of the first ref only.
    block_size
        Maximum read size and flush threshold in bytes. Falls back to
        ``DEFAULT_TRANSFER_BLOCK_SIZE`` when ``None`` or ``0``.
    """
    if not remote_fileobj_refs:
        # Nothing to copy to; also avoids indexing an empty list below.
        return

    client = await self._get_client(remote_fileobj_refs[0].address)
    block_size = block_size or DEFAULT_TRANSFER_BLOCK_SIZE
    pending: list = []
    pending_size = 0
    for file_obj, remote_ref in zip(local_fileobjs, remote_fileobj_refs):
        while True:
            file_data = await file_obj.read(block_size)  # type: ignore
            if not file_data:
                # Empty read means end of file; move on to the next pair.
                break
            pending.append((remote_ref.address, remote_ref.uid, file_data))
            pending_size += len(file_data)
            if pending_size >= block_size:
                message = self._gen_copy_to_fileobjs_message(pending)
                await self._call_with_client(client, message)
                # Rebind instead of clearing in place: the message that was
                # just sent still references the old list, and a receiver
                # that keeps it must not see its content vanish. This also
                # matches how `copy_to_buffers` resets its batch.
                pending = []
                pending_size = 0

    if pending:
        message = self._gen_copy_to_fileobjs_message(pending)
        await self._call_with_client(client, message)
4 changes: 4 additions & 0 deletions python/xoscar/backends/message.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class MessageType(Enum):
tell = 8
cancel = 9
copy_to_buffers = 10
copy_to_fileobjs = 11

class ControlMessageType(Enum):
stop = 0
Expand Down Expand Up @@ -75,6 +76,9 @@ class CopyToBuffersMessage(_MessageBase):
message_trace: list | None = None,
): ...

# Message type for file-object copies: subclasses CopyToBuffersMessage and
# only overrides the `message_type` tag.
class CopyToFileObjectsMessage(CopyToBuffersMessage):
    message_type = MessageType.copy_to_fileobjs

class ControlMessage(_MessageBase):
message_type = MessageType.control

Expand Down
6 changes: 6 additions & 0 deletions python/xoscar/backends/message.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class MessageType(Enum):
tell = 8
cancel = 9
copy_to_buffers = 10
copy_to_fileobjs = 11


class ControlMessageType(Enum):
Expand Down Expand Up @@ -520,6 +521,10 @@ cdef class CopyToBuffersMessage(_MessageBase):
self.content = subs[0]


# Message type for file-object copies: subclasses CopyToBuffersMessage and
# only overrides the `message_type` tag.
cdef class CopyToFileObjectsMessage(CopyToBuffersMessage):
    message_type = MessageType.copy_to_fileobjs


cdef dict _message_type_to_message_cls = {
MessageType.control.value: ControlMessage,
MessageType.result.value: ResultMessage,
Expand All @@ -532,6 +537,7 @@ cdef dict _message_type_to_message_cls = {
MessageType.tell.value: TellMessage,
MessageType.cancel.value: CancelMessage,
MessageType.copy_to_buffers.value: CopyToBuffersMessage,
MessageType.copy_to_fileobjs.value: CopyToFileObjectsMessage
}


Expand Down
13 changes: 10 additions & 3 deletions python/xoscar/backends/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from .._utils import TypeDispatcher, create_actor_ref, to_binary
from ..api import Actor
from ..core import ActorRef, BufferRef, register_local_pool
from ..core import ActorRef, BufferRef, FileObjectRef, register_local_pool
from ..debug import debug_async_timeout, record_message_trace
from ..entrypoints import init_extension_entrypoints
from ..errors import (
Expand Down Expand Up @@ -123,7 +123,8 @@ def _register_message_handler(pool_type: Type["AbstractActorPool"]):
(MessageType.tell, pool_type.tell),
(MessageType.cancel, pool_type.cancel),
(MessageType.control, pool_type.handle_control_command),
(MessageType.copy_to_buffers, pool_type.handle_copy_to_message),
(MessageType.copy_to_buffers, pool_type.handle_copy_to_buffers_message),
(MessageType.copy_to_fileobjs, pool_type.handle_copy_to_fileobjs_message),
]:
pool_type._message_handler[message_type] = handler # type: ignore
return pool_type
Expand Down Expand Up @@ -500,12 +501,18 @@ async def stop(self):
finally:
self._stopped.set()

async def handle_copy_to_message(self, message) -> ResultMessage:
async def handle_copy_to_buffers_message(self, message) -> ResultMessage:
for addr, uid, start, _len, data in message.content:
buffer = BufferRef.get_buffer(BufferRef(addr, uid))
buffer[start : start + _len] = data
return ResultMessage(message_id=message.message_id, result=True)

async def handle_copy_to_fileobjs_message(self, message) -> ResultMessage:
    """
    Write every chunk carried by a copy-to-fileobjs message to its local file.

    ``message.content`` is iterated as ``(address, uid, data)`` triples. Each
    triple is resolved to a local file object through ``FileObjectRef`` and
    ``data`` is written to it, in the order the triples appear. The reply is
    a ``ResultMessage`` with the same message id and ``result=True``.
    """
    for chunk in message.content:
        address, uid, payload = chunk
        target_ref = FileObjectRef(address, uid)
        local_file = FileObjectRef.get_local_file_object(target_ref)
        await local_file.write(payload)
    reply = ResultMessage(message_id=message.message_id, result=True)
    return reply

@property
def stopped(self) -> bool:
return self._stopped.is_set()
Expand Down
96 changes: 94 additions & 2 deletions python/xoscar/backends/test/tests/test_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@
# limitations under the License.
import asyncio
import os
import shutil
import sys
import tempfile
from typing import List, Optional

import numpy as np
import pytest

from .... import Actor, ActorRefType
from ....api import actor_ref, buffer_ref, copy_to
from ....aio import AioFileObject
from ....api import actor_ref, buffer_ref, copy_to, file_object_ref
from ....backends.allocate_strategy import ProcessIndex
from ....backends.indigen.pool import MainActorPool
from ....context import get_context
from ....core import BufferRef
from ....core import BufferRef, FileObjectRef
from ....tests.core import require_cupy
from ....utils import lazy_import
from ...pool import create_actor_pool
Expand Down Expand Up @@ -203,3 +206,92 @@ async def tests_gpu_copy(scheme):
if "ucx" == scheme:
await _copy_test(None, "ucx", False)
await _copy_test("ucx", None, False)


class FileobjTransferActor(Actor):
    """Test actor that copies file contents to a peer actor through ``copy_to``."""

    def __init__(self):
        # Aio file objects opened by `create_file_objects`; closed by `close`.
        self._fileobjs = []

    async def create_file_objects(self, names: List[str]) -> List[FileObjectRef]:
        """Open every path in ``names`` in ``"w+b"`` mode and return refs to them."""
        refs = []
        for name in names:
            afobj = AioFileObject(open(name, "w+b"))
            self._fileobjs.append(afobj)
            refs.append(file_object_ref(self.address, afobj))
        return refs

    async def close(self):
        """Close all tracked files, asserting each one's position is past zero."""
        for fobj in self._fileobjs:
            assert await fobj.tell() > 0
            await fobj.close()

    async def copy_data(
        self,
        ref: ActorRefType["FileobjTransferActor"],
        names1: List[str],
        names2: List[str],
        sizes: List[int],
    ):
        """
        Fill local files with random bytes, copy them to ``ref`` and compare.

        Parameters
        ----------
        ref
            The actor that owns the destination files.
        names1
            Paths of the local source files, created here.
        names2
            Paths of the destination files, opened by ``ref``.
        sizes
            Byte count for each source file, paired with ``names1``.
        """
        fobjs = []
        try:
            for name, size in zip(names1, sizes):
                fobj = open(name, "w+b")
                # Track the handle before writing so it is closed on failure.
                fobjs.append(AioFileObject(fobj))
                # Coerce to int: byte counts derived from float arithmetic
                # (e.g. 0.5 * 1024**2) are not valid lengths everywhere.
                fobj.write(np.random.bytes(int(size)))
                fobj.seek(0)

            ref = await actor_ref(ref)
            file_obj_refs = await ref.create_file_objects(names2)
            await copy_to(fobjs, file_obj_refs)
        finally:
            # Always release local handles, even when the copy raised.
            for local_file in fobjs:
                await local_file.close()  # type: ignore
        await ref.close()

        for n1, n2 in zip(names1, names2):
            with open(n1, "rb") as f1, open(n2, "rb") as f2:
                b1 = f1.read()
                b2 = f2.read()
                assert b1 == b2


@pytest.mark.asyncio
async def test_copy_to_file_objects():
    """Copy files of several sizes between two actors, then remove the temp dir."""
    start_method = (
        os.environ.get("POOL_START_METHOD", "forkserver")
        if sys.platform != "win32"
        else None
    )
    pool = await create_actor_pool(
        "127.0.0.1",
        pool_cls=MainActorPool,
        n_process=2,
        subprocess_start_method=start_method,
    )

    d = tempfile.mkdtemp()
    try:
        async with pool:
            ctx = get_context()

            # Place the two actors at different process indexes of the pool.
            actor_ref1 = await ctx.create_actor(
                FileobjTransferActor,
                uid="test-1",
                address=pool.external_address,
                allocate_strategy=ProcessIndex(1),
            )
            actor_ref2 = await ctx.create_actor(
                FileobjTransferActor,
                uid="test-2",
                address=pool.external_address,
                allocate_strategy=ProcessIndex(2),
            )
            # Whole-number byte counts: 10 MiB, 3 MiB, 512 KiB and 256 KiB.
            sizes = [10 * 1024**2, 3 * 1024**2, 1024**2 // 2, 1024**2 // 4]
            names = []
            for _ in range(2 * len(sizes)):
                fd, p = tempfile.mkstemp(dir=d)
                # mkstemp returns an open OS-level handle; close it so it does
                # not leak and does not keep the file locked for later cleanup.
                os.close(fd)
                names.append(p)

            await actor_ref1.copy_data(
                actor_ref2, names[::2], names[1::2], sizes=sizes
            )
    finally:
        try:
            shutil.rmtree(d)
        except PermissionError:
            # Best-effort cleanup: a file may still be held open by a process.
            pass
Loading

0 comments on commit 1ea9836

Please sign in to comment.