Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add session context support for Ray DAG mode #3358

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mars/deploy/oscar/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def new_ray_session(
client = new_cluster_in_ray(backend=backend, **new_cluster_kwargs)
session_id = session_id or client.session.session_id
address = client.address
logger.warning("CLIENT ADDRESS: %s", address)
session = new_session(
address=address, session_id=session_id, backend=backend, default=default
)
Expand Down
2 changes: 2 additions & 0 deletions mars/deploy/oscar/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,9 @@ async def fetch(self, *tileables, **kwargs) -> list:
chunks, chunk_metas, itertools.chain(*fetch_infos_list)
):
await fetcher.append(chunk.key, meta, fetch_info.indexes)
logger.warning("FETCH!! %r", fetcher)
fetched_data = await fetcher.get()
logger.warning("FETCH2!!")
for fetch_info, data in zip(
itertools.chain(*fetch_infos_list), fetched_data
):
Expand Down
31 changes: 31 additions & 0 deletions mars/deploy/oscar/tests/test_ray_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.

import copy
import logging
import os
import time

import pytest

from .... import get_context
from .... import remote as mr
from .... import tensor as mt
from ....session import new_session, get_default_async_session
from ....tests import test_session
Expand Down Expand Up @@ -125,6 +127,35 @@ def test_sync_execute(ray_start_regular_shared2, config):
test_local.test_sync_execute(config)


@require_ray
@pytest.mark.parametrize("config", [{"backend": "ray"}])
def test_spawn_execution(ray_start_regular_shared2, config):
    """A spawned remote function must be able to execute and fetch a nested
    spawned task from inside a Ray DAG session.

    Fix: drop the leftover WIP ``logging.warning`` debug calls — they wrote
    to the root logger and polluted test output without asserting anything.
    """
    session = new_session(
        backend=config["backend"],
        n_cpu=2,
        web=False,
        use_uvloop=False,
        config={"task.execution_config.ray.monitor_interval_seconds": 0},
    )

    # Web UI is disabled above, so no endpoint should be exposed.
    assert session._session.client.web_address is None
    assert session.get_web_endpoint() is None

    def f1(c=0):
        # Only the outer call (c truthy) spawns a nested task; the nested
        # call uses the default c=0, so the recursion terminates there.
        if c:
            executed = mr.spawn(f1).execute()
            executed.fetch()
        return c

    with session:
        assert 10 == mr.spawn(f1, 10).execute().fetch()

    session.stop_server()
    assert get_default_async_session() is None


@require_ray
@pytest.mark.parametrize(
"create_cluster",
Expand Down
8 changes: 8 additions & 0 deletions mars/lib/aio/isolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@

import asyncio
import atexit
import logging
import threading
from typing import Dict, Optional

logger = logging.getLogger(__name__)


class Isolation:
loop: asyncio.AbstractEventLoop
Expand All @@ -31,6 +34,9 @@ def __init__(self, loop: asyncio.AbstractEventLoop, threaded: bool = True):
self._thread = None
self._thread_ident = None

def __repr__(self):
    """Debug representation of the isolation: the loop's identity and repr,
    whether it runs on a dedicated thread, and that thread's ident.
    """
    # Fix: the original glued id(self.loop) directly onto the loop repr
    # with no separator, producing an unreadable string; also split the
    # overlong line.
    return (
        f"<Isolation loop={id(self.loop)} {self.loop!r} "
        f"threaded={self._threaded} thread_ident={self._thread_ident}>"
    )

def _run(self):
asyncio.set_event_loop(self.loop)
self._stopped = asyncio.Event()
Expand Down Expand Up @@ -72,9 +78,11 @@ def new_isolation(

if loop is None:
loop = asyncio.new_event_loop()
logger.warning("NEW_LOOP %d", id(loop))

isolation = Isolation(loop, threaded=threaded)
isolation.start()
logger.warning("NEW_ISOLATION! loop: %r", loop)
_name_to_isolation[name] = isolation
return isolation

Expand Down
4 changes: 3 additions & 1 deletion mars/services/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
local_address: str,
loop: asyncio.AbstractEventLoop,
band: BandType = None,
isolation_threaded: bool = False,
):
super().__init__(
session_id=session_id,
Expand All @@ -59,7 +60,8 @@ def __init__(
# new isolation with current loop,
# so that session created in tile and execute
# can get the right isolation
new_isolation(loop=self._loop, threaded=False)
logger.warning("NEW_ISOLATION in ThreadedServiceContext.__init__")
new_isolation(loop=self._loop, threaded=isolation_threaded)

self._running_session_id = None
self._running_op_key = None
Expand Down
23 changes: 19 additions & 4 deletions mars/services/task/execution/ray/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from typing import Dict, List, Callable

from .....core.context import Context
from .....session import ensure_isolation_created
from .....storage.base import StorageLevel
from .....typing import ChunkType
from .....typing import ChunkType, SessionType
from .....utils import implements, lazy_import, sync_to_async
from ....context import ThreadedServiceContext
from .config import RayExecutionConfig
Expand Down Expand Up @@ -187,13 +188,27 @@ def get_worker_addresses(self) -> List[str]:


# TODO(fyrestone): Implement more APIs for Ray.
class RayExecutionWorkerContext(_RayRemoteObjectContext, dict):
class RayExecutionWorkerContext(_RayRemoteObjectContext, ThreadedServiceContext, dict):
"""The context for executing operands."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(
    self,
    get_or_create_actor: Callable[[], "ray.actor.ActorHandle"],
    *args,
    **kwargs,
):
    """Initialize the per-worker execution context.

    Explicitly initializes the two cooperating bases: the remote-object
    mixin (which forwards to the service-context chain) and the ``dict``
    base that stores chunk results by key.
    """
    # NOTE(review): loop=None with a threaded isolation — presumably
    # because the Ray worker process has no running event loop of its
    # own; confirm against ThreadedServiceContext.__init__.
    _RayRemoteObjectContext.__init__(
        self,
        get_or_create_actor,
        *args,
        loop=None,
        isolation_threaded=True,
        **kwargs,
    )
    dict.__init__(self)
    # Chunk currently being executed; populated via set_current_chunk().
    self._current_chunk = None

@implements(Context.get_current_session)
def get_current_session(self) -> SessionType:
    """Return a Ray-backend session for this context's session id,
    connected to ``self.supervisor_address`` (passing ``new=False`` and
    ``default=False`` through to ``new_session``).
    """
    # Local import: presumably avoids a circular dependency between the
    # session module and the execution context — TODO confirm.
    from .....session import new_session

    session = new_session(
        self.supervisor_address,
        self.session_id,
        backend="ray",
        new=False,
        default=False,
    )
    return session

@classmethod
@implements(Context.new_custom_log_dir)
def new_custom_log_dir(cls):
Expand Down
64 changes: 49 additions & 15 deletions mars/services/task/execution/ray/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
import itertools
import logging
import operator
import os
import time
from dataclasses import dataclass, field
from typing import List, Dict, Any, Callable

import numpy as np

from .....core import ChunkGraph, Chunk, TileContext
from .....core.context import set_context
from .....core.context import set_context, get_context
from .....core.operand import (
Fetch,
Fuse,
Expand All @@ -38,6 +39,7 @@
from .....metrics.api import init_metrics, Metrics
from .....resource import Resource
from .....serialization import serialize, deserialize
from .....session import AbstractSession, get_default_session
from .....typing import BandType
from .....utils import (
aiotask_wrapper,
Expand Down Expand Up @@ -149,10 +151,12 @@ def gc_inputs(self, chunk: Chunk):


def execute_subtask(
session_id: str,
subtask_id: str,
subtask_chunk_graph: ChunkGraph,
output_meta_n_keys: int,
is_mapper,
address: str,
*inputs,
):
"""
Expand All @@ -176,6 +180,9 @@ def execute_subtask(
-------
subtask outputs and meta for outputs if `output_meta_keys` is provided.
"""
logging.basicConfig(level=logging.INFO)
logger.setLevel(logging.INFO)

init_metrics("ray")
started_subtask_number.record(1)
ray_task_id = ray.get_runtime_context().get_task_id()
Expand All @@ -184,7 +191,16 @@ def execute_subtask(
# Optimize chunk graph.
subtask_chunk_graph = _optimize_subtask_graph(subtask_chunk_graph)
fetch_chunks, shuffle_fetch_chunk = _get_fetch_chunks(subtask_chunk_graph)
context = RayExecutionWorkerContext(RayTaskState.get_handle)

context = RayExecutionWorkerContext(
RayTaskState.get_handle,
session_id,
address,
address,
address,
)
set_context(context)

if shuffle_fetch_chunk is not None:
# The subtask is a reducer subtask.
n_mappers = shuffle_fetch_chunk.op.n_mappers
Expand All @@ -209,19 +225,28 @@ def execute_subtask(
# Update non shuffle inputs to context.
context.update(zip((start_chunk.key for start_chunk in fetch_chunks), inputs))

for chunk in subtask_chunk_graph.topological_iter():
if chunk.key not in context:
try:
context.set_current_chunk(chunk)
execute(context, chunk.op)
except Exception:
logger.exception(
"Execute operand %s of graph %s failed.",
chunk.op,
subtask_chunk_graph.to_dot(),
)
raise
subtask_gc.gc_inputs(chunk)
default_session = get_default_session()
try:
context.get_current_session().as_default()

for chunk in subtask_chunk_graph.topological_iter():
if chunk.key not in context:
try:
context.set_current_chunk(chunk)
execute(context, chunk.op)
except Exception:
logger.exception(
"Execute operand %s of graph %s failed.",
chunk.op,
subtask_chunk_graph.to_dot(),
)
raise
subtask_gc.gc_inputs(chunk)
finally:
if default_session is not None:
default_session.as_default()
else:
AbstractSession.reset_default()

# For non-mapper subtask, output context is chunk key to results.
# For mapper subtasks, output context is data key to results.
Expand Down Expand Up @@ -455,6 +480,7 @@ def __init__(
task_chunks_meta: Dict[str, _RayChunkMeta],
lifecycle_api: LifecycleAPI,
meta_api: MetaAPI,
address: str,
):
logger.info(
"Start task %s with GC method %s.",
Expand All @@ -475,6 +501,8 @@ def __init__(
self._available_band_resources = None
self._result_tileables_lifecycle = None

self._address = address

# For progress and task cancel
self._stage_index = 0
self._pre_all_stages_progress = 0.0
Expand Down Expand Up @@ -507,6 +535,7 @@ async def create(
task_chunks_meta,
lifecycle_api,
meta_api,
address,
)
available_band_resources = await executor.get_available_band_resources()
worker_addresses = list(
Expand Down Expand Up @@ -710,10 +739,12 @@ async def _execute_subtask_graph(
memory=subtask_memory,
scheduling_strategy="DEFAULT" if len(input_object_refs) else "SPREAD",
).remote(
subtask.session_id,
subtask.subtask_id,
serialize(subtask_chunk_graph, context={"serializer": "ray"}),
subtask.stage_n_outputs,
is_mapper,
self._address,
*input_object_refs,
)
await asyncio.sleep(0)
Expand All @@ -739,6 +770,7 @@ async def _execute_subtask_graph(
task_context[chunk_key] = object_ref
logger.info("Submitted %s subtasks of stage %s.", len(subtask_graph), stage_id)

logger.warning("SUBTASK_RUN_1")
monitor_context.stage = _RayExecutionStage.WAITING
key_to_meta = {}
if len(output_meta_object_refs) > 0:
Expand All @@ -752,6 +784,7 @@ async def _execute_subtask_graph(
self._task_chunks_meta[key] = _RayChunkMeta(memory_size=memory_size)
logger.info("Got %s metas of stage %s.", meta_count, stage_id)

logger.warning("SUBTASK_RUN_2")
chunk_to_meta = {}
# ray.wait requires the object ref list is unique.
output_object_refs = set()
Expand All @@ -773,6 +806,7 @@ async def _execute_subtask_graph(
await asyncio.to_thread(ray.wait, list(output_object_refs), fetch_local=False)

logger.info("Complete stage %s.", stage_id)
logger.warning("%d: SUBTASK_RUN_3: %r", os.getpid(), output_object_refs)
return chunk_to_meta

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand Down
14 changes: 11 additions & 3 deletions mars/services/task/execution/ray/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@

import asyncio
import functools
import logging
from collections import namedtuple
from typing import Dict, List

from .....utils import lazy_import
from ..api import Fetcher, register_fetcher_cls

logger = logging.getLogger(__name__)

ray = lazy_import("ray")
_FetchInfo = namedtuple("FetchInfo", ["key", "object_ref", "conditions"])

Expand All @@ -36,9 +39,10 @@ class RayFetcher(Fetcher):
name = "ray"
required_meta_keys = ("object_refs",)

def __init__(self, **kwargs):
def __init__(self, loop=None, **kwargs):
self._fetch_info_list = []
self._no_conditions = True
self._loop = loop

@staticmethod
@functools.lru_cache(maxsize=None) # Specify maxsize=None to make it faster
Expand All @@ -55,9 +59,12 @@ async def append(self, chunk_key: str, chunk_meta: Dict, conditions: List = None

async def get(self):
if self._no_conditions:
logger.warning(f"FETCHER_0 {self._fetch_info_list}")
return await asyncio.gather(
*(info.object_ref for info in self._fetch_info_list)
*(info.object_ref for info in self._fetch_info_list),
loop=self._loop,
)
logger.warning("FETCHER_1")
refs = [None] * len(self._fetch_info_list)
for index, fetch_info in enumerate(self._fetch_info_list):
if fetch_info.conditions is None:
Expand All @@ -66,4 +73,5 @@ async def get(self):
refs[index] = self._remote_query_object_with_condition().remote(
fetch_info.object_ref, tuple(fetch_info.conditions)
)
return await asyncio.gather(*refs)
logger.warning("FETCHER_2")
return await asyncio.gather(*refs, loop=self._loop)
Loading