Implement parallel execution for DAG tasks #4128

Open: wants to merge 56 commits into base: advanced-dag

Commits (56)
7c0965a
provide an example, edited from pipeline.yml
andylizf Oct 10, 2024
6949ac1
more focus on dependencies for user dag lib
andylizf Oct 10, 2024
55db40b
more powerful user interface
andylizf Oct 11, 2024
db7ff9f
load and dump new yaml format
andylizf Oct 11, 2024
054cc26
fix
andylizf Oct 11, 2024
24ef94e
fix: reversed logic in add_edge
andylizf Oct 11, 2024
129bdbf
rename
andylizf Oct 14, 2024
12ec5a4
refactor due to reviewer's comments
andylizf Oct 14, 2024
9497a3e
generate task.name if not given
andylizf Oct 14, 2024
ff528a5
add comments for add_edge
andylizf Oct 15, 2024
04c6f9d
add `print_exception_no_traceback` when raise
andylizf Oct 15, 2024
4985813
make `Dag.tasks` a property
andylizf Oct 15, 2024
48a2826
print dependencies for `__repr__`
andylizf Oct 15, 2024
78d826d
move `get_unique_task_name` to common_utils
andylizf Oct 15, 2024
e88acc1
rename methods to use downstream/edge terminology
andylizf Oct 17, 2024
651789d
Add dependencies feature for task dependency management (#4067)
andylizf Oct 10, 2024
4bc8b89
fix(jobs): type errors
andylizf Oct 19, 2024
4ba76c3
refactor: `_update_failed_task_state` for unified error handling
andylizf Oct 19, 2024
e1b27f3
refactor: separate finally block for a meaningful name
andylizf Oct 19, 2024
c102f5d
feat: simple parallel execution support
andylizf Oct 19, 2024
e4fbb28
Apply suggestions from code review
andylizf Oct 20, 2024
a27969b
change wording all to up/downstream style
andylizf Oct 20, 2024
8486352
Add unique suffix to task names, fallback to timestamp if unnamed
andylizf Oct 20, 2024
c14980e
Unify handling of single and multiple tasks without dependencies
andylizf Oct 20, 2024
66fc864
Refactor tasks initialization: use list comprehension and fail fast
andylizf Oct 20, 2024
65d0bdd
Fix remove task dependency description: upstream, not downstream
andylizf Oct 20, 2024
28b6482
Remove duplicated `self.edges`, use nx api instead
andylizf Oct 20, 2024
1792ba6
Revert "Add unique suffix to task names, fallback to timestamp if unn…
andylizf Oct 20, 2024
f600d16
comment the checking used as upstream logic
andylizf Oct 20, 2024
a673e22
remove is_chain restriction in jobs launch
andylizf Oct 20, 2024
5dfeb32
Change Static layered parallelism to dynamic queue parallelism
andylizf Oct 21, 2024
4c91f94
Merge branch 'job-dag' into dag-execute
andylizf Oct 21, 2024
7e500ee
fix: add blocked_tasks set
andylizf Oct 21, 2024
4c8954e
refactor and update canceled tasks in database
andylizf Oct 21, 2024
da42a24
format: some types are unsubscriptable
andylizf Oct 21, 2024
25281f6
add some logging
andylizf Oct 21, 2024
aa763b1
Merge remote-tracking branch 'upstream/advanced-dag' into dag-execute
andylizf Oct 21, 2024
0938158
Fooled again by the silly finally block, mistakenly thought it only r…
andylizf Oct 22, 2024
bb94930
Merge remote-tracking branch 'upstream/advanced-dag' into dag-execute
andylizf Oct 23, 2024
98044c9
feat: redirect logging for each thread to a separate file to prevent …
andylizf Oct 24, 2024
d107a73
fix due to reviewer's suggestions and some nits
andylizf Oct 26, 2024
43e0f19
chore: remove some debugging info
andylizf Oct 26, 2024
b001b9c
partially revert "update canceled tasks in database", view a task as …
andylizf Oct 26, 2024
ba4a0d1
add some comments to inform the future of thread-level redirector
andylizf Oct 26, 2024
6b3b4c8
cancel all tasks when a task failed
andylizf Oct 26, 2024
934bde3
combine 3 sets to 1 dict
andylizf Oct 26, 2024
59195da
add some comments to illustrate `_try_add_successors_to_queue` only q…
andylizf Oct 27, 2024
609d57d
Cancel all non-running tasks when cancelling a job. Left a TODO for f…
andylizf Oct 27, 2024
9d577e5
make those logging files and dir
andylizf Oct 27, 2024
c50a881
Merge branch 'advanced-dag' into dag-execute
andylizf Oct 27, 2024
af16aab
refactor: stream_logs_by_id
andylizf Oct 27, 2024
0095b98
refactor and format
andylizf Oct 27, 2024
cecdd7b
provide a cli to print run.log of a certain subtask
andylizf Oct 28, 2024
837d7ab
format
andylizf Oct 28, 2024
5f7e50d
add some comments and checks
andylizf Oct 28, 2024
b9e143d
clearer log
andylizf Oct 28, 2024
33 changes: 23 additions & 10 deletions sky/backends/cloud_vm_ray_backend.py
@@ -3281,7 +3281,7 @@ def _exec_code_on_head(
) -> None:
"""Executes generated code on the head node."""
style = colorama.Style

fore = colorama.Fore
script_path = os.path.join(SKY_REMOTE_APP_DIR, f'sky_job_{job_id}')
remote_log_dir = self.log_dir
remote_log_path = os.path.join(remote_log_dir, 'run.log')
@@ -3291,7 +3291,7 @@ def _exec_code_on_head(
mkdir_code = (f'{cd} && mkdir -p {remote_log_dir} && '
f'touch {remote_log_path}')
encoded_script = shlex.quote(codegen)
create_script_code = (f'{{ echo {encoded_script} > {script_path}; }}')
create_script_code = f'{{ echo {encoded_script} > {script_path}; }}'
job_submit_cmd = (
f'RAY_DASHBOARD_PORT=$({constants.SKY_PYTHON_CMD} -c "from sky.skylet import job_lib; print(job_lib.get_job_submission_port())" 2> /dev/null || echo 8265);' # pylint: disable=line-too-long
f'{cd} && {constants.SKY_RAY_CMD} job submit '
@@ -3377,11 +3377,17 @@ def _dump_code_to_file(codegen: str) -> None:
logger.info(
ux_utils.starting_message(f'Job submitted, ID: {job_id}'))
rich_utils.stop_safe_status()

is_dag_chain = managed_job_dag is not None and managed_job_dag.is_chain(
)
try:
if not detach_run:
if (handle.cluster_name in controller_utils.Controllers.
JOBS_CONTROLLER.value.candidate_cluster_names):
self.tail_managed_job_logs(handle, job_id)
if is_dag_chain:
self.tail_managed_job_logs(handle, job_id)
else:
pass
else:
# Sky logs. Not using subprocess.run since it will make the
# ssh keep connected after ctrl-c.
@@ -3390,15 +3396,20 @@ def _dump_code_to_file(codegen: str) -> None:
name = handle.cluster_name
controller = controller_utils.Controllers.from_name(name)
if controller == controller_utils.Controllers.JOBS_CONTROLLER:
cluster_logs = (f'\n{ux_utils.INDENT_SYMBOL}To stream job logs '
f'for chain workflow:\t\t\t'
f'{ux_utils.BOLD}sky jobs logs {job_id}'
) if is_dag_chain else ''
logger.info(
f'\n📋 Useful Commands'
f'\nManaged Job ID: '
f'\n{fore.CYAN}Managed Job ID: '
f'{style.BRIGHT}{job_id}{style.RESET_ALL}'
f'\n📋 Useful Commands'
f'\n{ux_utils.INDENT_SYMBOL}To cancel the job:\t\t\t'
f'{ux_utils.BOLD}sky jobs cancel {job_id}'
f'{ux_utils.RESET_BOLD}'
f'\n{ux_utils.INDENT_SYMBOL}To stream job logs:\t\t\t'
f'{ux_utils.BOLD}sky jobs logs {job_id}'
f'{cluster_logs}'
f'\n{ux_utils.INDENT_SYMBOL}To stream job DAG logs:\t\t\t'
f'{ux_utils.BOLD}sky jobs logs {job_id} --task-id [TASK_ID]'
f'{ux_utils.RESET_BOLD}'
f'\n{ux_utils.INDENT_SYMBOL}To stream controller logs:\t\t'
f'{ux_utils.BOLD}sky jobs logs --controller {job_id}'
Expand All @@ -3410,8 +3421,9 @@ def _dump_code_to_file(codegen: str) -> None:
f'dashboard:\t{ux_utils.BOLD}sky jobs dashboard'
f'{ux_utils.RESET_BOLD}')
elif controller is None:
logger.info(f'\n📋 Useful Commands'
f'\nJob ID: {job_id}'
logger.info(f'\n{fore.CYAN}Job ID: '
f'{style.BRIGHT}{job_id}{style.RESET_ALL}'
f'\n📋 Useful Commands'
f'\n{ux_utils.INDENT_SYMBOL}To cancel the job:\t\t'
f'{ux_utils.BOLD}sky cancel {name} {job_id}'
f'{ux_utils.RESET_BOLD}'
@@ -3773,12 +3785,13 @@ def tail_managed_job_logs(self,
handle: CloudVmRayResourceHandle,
job_id: Optional[int] = None,
job_name: Optional[str] = None,
task_id: Optional[int] = None,
controller: bool = False,
follow: bool = True) -> None:
# if job_name is not None, job_id should be None
assert job_name is None or job_id is None, (job_name, job_id)
code = managed_jobs.ManagedJobCodeGen.stream_logs(
job_name, job_id, follow, controller)
job_name, job_id, task_id, follow, controller)

# With the stdin=subprocess.DEVNULL, the ctrl-c will not directly
# kill the process, so we need to handle it manually here.
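The hunks above gate inline log streaming on the shape of the managed-job DAG: `is_dag_chain` is computed once, only chain workflows keep the plain `sky jobs logs {job_id}` streaming hint, and a per-task variant (`--task-id [TASK_ID]`) is surfaced for general DAGs. As a rough illustration of what such a chain check amounts to, here is a sketch assuming a networkx graph; this is not the actual `sky.Dag.is_chain` implementation:

```python
import networkx as nx


def is_chain(graph: nx.DiGraph) -> bool:
    """A DAG is a chain when its tasks form one linear sequence:
    a single connected component with no fan-in and no fan-out."""
    if graph.number_of_nodes() <= 1:
        return True
    return (nx.is_weakly_connected(graph) and
            all(graph.in_degree(n) <= 1 and graph.out_degree(n) <= 1
                for n in graph.nodes()))
```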
21 changes: 19 additions & 2 deletions sky/cli.py
@@ -3830,14 +3830,31 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
default=False,
help=('Show the controller logs of this job; useful for debugging '
'launching/recoveries, etc.'))
@click.option('--task-id',
required=False,
type=int,
help='Tail the logs of a specific task.')
@click.argument('job_id', required=False, type=int)
@usage_lib.entrypoint
def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool,
controller: bool):
def jobs_logs(name: Optional[str], job_id: Optional[int],
task_id: Optional[int], follow: bool, controller: bool):
"""Tail the log of a managed job."""
if name is not None and job_id is not None:
raise ValueError('Cannot specify both name and job_id.')

if task_id is not None:
if job_id is None:
with ux_utils.print_exception_no_traceback():
raise ValueError('Must specify job_id when specifying task_id.')
if controller:
with ux_utils.print_exception_no_traceback():
raise ValueError('Cannot specify both task_id and controller.')
# TODO(andy): Add validation to ensure either `--task-id` or `--controller`
# is specified when dealing with non-linear job DAGs.
try:
managed_jobs.tail_logs(name=name,
job_id=job_id,
task_id=task_id,
follow=follow,
controller=controller)
except exceptions.ClusterNotUpError:
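The new `--task-id` flag lets users tail the run log of one subtask of a non-linear DAG, e.g. `sky jobs logs <JOB_ID> --task-id <TASK_ID>`; it requires a job ID and is mutually exclusive with `--controller`. A hedged sketch of the equivalent programmatic call, based on the `managed_jobs.tail_logs` invocation shown in the diff (the import path is an assumption and the IDs are placeholders):

```python
# Illustrative only: stream the run log of one subtask of a managed job.
# Take real job/task IDs from `sky jobs queue`.
from sky import jobs as managed_jobs

job_id = 1    # managed job ID
task_id = 0   # which DAG subtask's run.log to tail

managed_jobs.tail_logs(name=None,
                       job_id=job_id,
                       task_id=task_id,
                       follow=True,
                       controller=False)
```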
202 changes: 140 additions & 62 deletions sky/jobs/controller.py
@@ -1,12 +1,15 @@
"""Controller: handles the life cycle of a managed job."""
import argparse
from concurrent import futures
import enum
import multiprocessing
import os
import pathlib
import queue
import time
import traceback
import typing
from typing import Tuple
from typing import Dict, Tuple

import filelock

@@ -43,25 +46,36 @@ def _get_dag_and_name(dag_yaml: str) -> Tuple['sky.Dag', str]:
return dag, dag_name


class TaskStatus(enum.Enum):
COMPLETED = 'completed'
FAILED = 'failed'
CANCELLED = 'cancelled'


class JobsController:
"""Each jobs controller manages the life cycle of one managed job."""

def __init__(self, job_id: int, dag_yaml: str,
retry_until_up: bool) -> None:
self._job_id = job_id
self._dag, self._dag_name = _get_dag_and_name(dag_yaml)
self._num_tasks = len(self._dag.tasks)
logger.info(self._dag)
self._retry_until_up = retry_until_up
# TODO(zhwu): this assumes the specific backend.
self._backend = cloud_vm_ray_backend.CloudVmRayBackend()

# pylint: disable=line-too-long
self._dag_graph = self._dag.get_graph()
self._task_queue = self._initialize_task_queue()
self._task_status: Dict['sky.Task', TaskStatus] = {}

# Add a unique identifier to the task environment variables, so that
# the user can have the same id for multiple recoveries.
# Example value: sky-2022-10-04-22-46-52-467694_my-spot-name_spot_id-17-0
# Example value:
# sky-2022-10-04-22-46-52-467694_my-spot-name_spot_id-17-0
job_id_env_vars = []
for i, task in enumerate(self._dag.tasks):
if len(self._dag.tasks) <= 1:
if self._num_tasks <= 1:
task_name = self._dag_name
else:
assert task.name is not None, task
Expand All @@ -86,6 +100,14 @@ def __init__(self, job_id: int, dag_yaml: str,
job_id_env_vars)
task.update_envs(task_envs)

def _initialize_task_queue(self) -> queue.Queue:
task_queue: queue.Queue = queue.Queue()
for task in self._dag_graph.nodes():
if self._dag_graph.in_degree(task) == 0:
task_id = self._dag.tasks.index(task)
task_queue.put(task_id)
return task_queue

def _download_log_and_stream(
self,
handle: cloud_vm_ray_backend.CloudVmRayResourceHandle) -> None:
@@ -323,78 +345,133 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
recovered_time=recovered_time,
callback_func=callback_func)

def run(self):
"""Run controller logic and handle exceptions."""
task_id = 0
def _try_add_successors_to_queue(self, task_id: int):
"""Tasks with multiple predecessors will only be queued once, as
`_handle_future_completion` runs sequentially in the main thread via
`futures.wait()`.
"""
is_task_runnable = lambda task: (all(
self._task_status.get(pred) == TaskStatus.COMPLETED
for pred in self._dag_graph.predecessors(task)
) and self._task_status.get(task) != TaskStatus.CANCELLED)
task = self._dag.tasks[task_id]
for successor in self._dag_graph.successors(task):
successor_id = self._dag.tasks.index(successor)
if is_task_runnable(successor):
self._task_queue.put(successor_id)

def _handle_future_completion(self, future: futures.Future, task_id: int):
succeeded = False
try:
succeeded = True
# We support chain DAGs only for now.
for task_id, task in enumerate(self._dag.tasks):
succeeded = self._run_one_task(task_id, task)
if not succeeded:
break
succeeded = future.result()
except exceptions.ProvisionPrechecksError as e:
# Please refer to the docstring of self._run for the cases when
# this exception can occur.
# Please refer to the docstring of self._run for
# the cases when this exception can occur.
failure_reason = ('; '.join(
common_utils.format_exception(reason, use_bracket=True)
for reason in e.reasons))
logger.error(failure_reason)
managed_job_state.set_failed(
self._job_id,
task_id=task_id,
failure_type=managed_job_state.ManagedJobStatus.
FAILED_PRECHECKS,
failure_reason=failure_reason,
callback_func=managed_job_utils.event_callback_func(
job_id=self._job_id,
task_id=task_id,
task=self._dag.tasks[task_id]))
self._update_failed_task_state(
task_id, managed_job_state.ManagedJobStatus.FAILED_PRECHECKS,
failure_reason)
except exceptions.ManagedJobReachedMaxRetriesError as e:
# Please refer to the docstring of self._run for the cases when
# this exception can occur.
logger.error(common_utils.format_exception(e))
# The managed job should be marked as FAILED_NO_RESOURCE, as the
# managed job may be able to launch next time.
managed_job_state.set_failed(
self._job_id,
task_id=task_id,
failure_type=managed_job_state.ManagedJobStatus.
FAILED_NO_RESOURCE,
failure_reason=common_utils.format_exception(e),
callback_func=managed_job_utils.event_callback_func(
job_id=self._job_id,
task_id=task_id,
task=self._dag.tasks[task_id]))
# Please refer to the docstring of self._run for
# the cases when this exception can occur.
failure_reason = common_utils.format_exception(e)
logger.error(failure_reason)
# The managed job should be marked as
# FAILED_NO_RESOURCE, as the managed job may be able to
# launch next time.
self._update_failed_task_state(
task_id, managed_job_state.ManagedJobStatus.FAILED_NO_RESOURCE,
failure_reason)
except (Exception, SystemExit) as e: # pylint: disable=broad-except
with ux_utils.enable_traceback():
logger.error(traceback.format_exc())
msg = ('Unexpected error occurred: '
f'{common_utils.format_exception(e, use_bracket=True)}')
msg = ('Unexpected error occurred: ' +
common_utils.format_exception(e, use_bracket=True))
logger.error(msg)
managed_job_state.set_failed(
self._job_id,
task_id=task_id,
failure_type=managed_job_state.ManagedJobStatus.
FAILED_CONTROLLER,
failure_reason=msg,
callback_func=managed_job_utils.event_callback_func(
job_id=self._job_id,
task_id=task_id,
task=self._dag.tasks[task_id]))
self._update_failed_task_state(
task_id, managed_job_state.ManagedJobStatus.FAILED_CONTROLLER,
msg)
finally:
# This will set all unfinished tasks to CANCELLING, and will not
# affect the jobs in terminal states.
# We need to call set_cancelling before set_cancelled to make sure
# the table entries are correctly set.
callback_func = managed_job_utils.event_callback_func(
task = self._dag.tasks[task_id]
if succeeded:
logger.info(
f'Task {task_id} completed with result: {succeeded}')
self._task_status[task] = TaskStatus.COMPLETED
self._try_add_successors_to_queue(task_id)
else:
logger.info(f'Adding task {task_id} to failed tasks.')
self._task_status[task] = TaskStatus.FAILED
self._cancel_all_tasks(task_id)

def _cancel_all_tasks(self, task_id: int):
callback_func = managed_job_utils.event_callback_func(
job_id=self._job_id, task_id=task_id, task=self._dag.tasks[task_id])
for task in self._dag.tasks:
if task not in self._task_status:
self._task_status.setdefault(task, TaskStatus.CANCELLED)

# Call set_cancelling before set_cancelled to make sure the table
# entries are correctly set.
managed_job_state.set_cancelling(self._job_id, callback_func)
managed_job_state.set_cancelled(self._job_id, callback_func)

def run(self):
"""Run controller logic and handle exceptions."""
all_tasks_completed = lambda: self._num_tasks == len(self._task_status)
# TODO(andy): Serve has logic to prevent too many services from running
# at the same time. We should have similar logic here, but it should
# count the total number of subtasks (an upper bound) instead of
# the number of jobs (dags).
# Further, we could try to calculate the maximum concurrency in the dag
# (e.g. for a chain dag it is 1 instead of n), which could allow us to
# run more dags in parallel.
max_workers = self._num_tasks
managed_job_utils.make_launch_log_dir_for_redirection(
self._backend.run_timestamp, self._num_tasks)
with futures.ThreadPoolExecutor(max_workers) as executor:
future_to_task = {}
while not all_tasks_completed():
while not self._task_queue.empty():
task_id = self._task_queue.get()
log_file_name = managed_job_utils.get_launch_log_file_name(
self._backend.run_timestamp, task_id)

logger.info(
f'Task {task_id} is submitted to run. To prevent '
f'interleaving, the launch logs are redirected to '
f'{ux_utils.BOLD}{log_file_name}{ux_utils.RESET_BOLD}')

with ux_utils.RedirectOutputForThread() as redirector:
future = executor.submit(
redirector.run(self._run_one_task, log_file_name),
task_id, self._dag.tasks[task_id])
future_to_task[future] = task_id

done, _ = futures.wait(future_to_task.keys(),
return_when='FIRST_COMPLETED')

for future in done:
logger.info(f'Task {future_to_task[future]} completed.')
task_id = future_to_task.pop(future)
self._handle_future_completion(future, task_id)

def _update_failed_task_state(
self, task_id: int,
failure_type: managed_job_state.ManagedJobStatus,
failure_reason: str):
"""Update the state of the failed task."""
managed_job_state.set_failed(
self._job_id,
task_id=task_id,
failure_type=failure_type,
failure_reason=failure_reason,
callback_func=managed_job_utils.event_callback_func(
job_id=self._job_id,
task_id=task_id,
task=self._dag.tasks[task_id])
managed_job_state.set_cancelling(job_id=self._job_id,
callback_func=callback_func)
managed_job_state.set_cancelled(job_id=self._job_id,
callback_func=callback_func)
task=self._dag.tasks[task_id]))


def _run_controller(job_id: int, dag_yaml: str, retry_until_up: bool):
@@ -480,6 +557,7 @@ def start(job_id, dag_yaml, retry_until_up):
except exceptions.ManagedJobUserCancelledError:
dag, _ = _get_dag_and_name(dag_yaml)
task_id, _ = managed_job_state.get_latest_task_id_status(job_id)
assert task_id is not None, task_id
logger.info(
f'Cancelling managed job, job_id: {job_id}, task_id: {task_id}')
managed_job_state.set_cancelling(
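Taken together, the controller changes replace the chain-only loop with dynamic queue parallelism: tasks whose predecessors have all completed are put on a queue, a thread pool launches them, and `futures.wait(..., return_when=FIRST_COMPLETED)` lets the main thread handle completions sequentially, which is why successors are only ever queued once and why a single failure can mark everything unfinished as cancelled. A minimal, self-contained sketch of that pattern, assuming networkx and an illustrative `run_task(node) -> bool` callable (SkyPilot-specific state updates, recovery, and per-task log redirection are omitted):

```python
import queue
from concurrent import futures

import networkx as nx


def run_dag(graph: nx.DiGraph, run_task) -> dict:
    """Run DAG nodes in parallel; a node starts once all parents succeed."""
    status: dict = {}            # node -> 'completed' | 'failed' | 'cancelled'
    ready: queue.Queue = queue.Queue()
    for node in graph.nodes():   # seed with root tasks (in-degree 0)
        if graph.in_degree(node) == 0:
            ready.put(node)

    max_workers = max(graph.number_of_nodes(), 1)
    with futures.ThreadPoolExecutor(max_workers) as pool:
        future_to_node = {}
        while len(status) < graph.number_of_nodes():
            while not ready.empty():
                node = ready.get()
                future_to_node[pool.submit(run_task, node)] = node
            done, _ = futures.wait(future_to_node,
                                   return_when=futures.FIRST_COMPLETED)
            for fut in done:
                node = future_to_node.pop(fut)
                if fut.exception() is None and fut.result():
                    status[node] = 'completed'
                    # Queue successors whose predecessors have all completed.
                    for succ in graph.successors(node):
                        if all(status.get(p) == 'completed'
                               for p in graph.predecessors(succ)):
                            ready.put(succ)
                else:
                    # Mark this task as failed and everything not yet finished
                    # as cancelled; already-running futures simply run out.
                    status[node] = 'failed'
                    for other in graph.nodes():
                        status.setdefault(other, 'cancelled')
    return status
```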