Skip to content

Commit

Permalink
fix due to reviwer's suggestions and some nits
Browse files Browse the repository at this point in the history
  • Loading branch information
andylizf committed Oct 26, 2024
1 parent 98044c9 commit d107a73
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 24 deletions.
2 changes: 1 addition & 1 deletion sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3291,7 +3291,7 @@ def _exec_code_on_head(
mkdir_code = (f'{cd} && mkdir -p {remote_log_dir} && '
f'touch {remote_log_path}')
encoded_script = shlex.quote(codegen)
create_script_code = (f'{{ echo {encoded_script} > {script_path}; }}')
create_script_code = f'{{ echo {encoded_script} > {script_path}; }}'
job_submit_cmd = (
f'RAY_DASHBOARD_PORT=$({constants.SKY_PYTHON_CMD} -c "from sky.skylet import job_lib; print(job_lib.get_job_submission_port())" 2> /dev/null || echo 8265);' # pylint: disable=line-too-long
f'{cd} && {constants.SKY_RAY_CMD} job submit '
Expand Down
30 changes: 17 additions & 13 deletions sky/jobs/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
recovered_time=recovered_time,
callback_func=callback_func)

def _try_add_successors_to_queue(self, task_id):
def _try_add_successors_to_queue(self, task_id: int):
is_task_runnable = lambda task: (all(
self._dag.tasks.index(pred) in self._completed_tasks
for pred in self._dag_graph.predecessors(task)) and task_id not in
Expand All @@ -355,8 +355,6 @@ def _handle_future_completion(self, future: futures.Future, task_id: int):
try:
succeeded = future.result()
except exceptions.ProvisionPrechecksError as e:
logger.info(f'Task {task_id} failed with ProvisionPrechecksError '
f'{e}')
# Please refer to the docstring of self._run for
# the cases when this exception can occur.
failure_reason = ('; '.join(
Expand Down Expand Up @@ -420,27 +418,33 @@ def _cancel_downstream_tasks(self, task_id: int):

def run(self):
"""Run controller logic and handle exceptions."""
all_tasks_completed = lambda: (len(self._completed_tasks) + len(
self._failed_tasks) + len(self._block_tasks) == len(self._dag.tasks)
)
with futures.ThreadPoolExecutor(
max_workers=len(self._dag.tasks)) as executor:
all_tasks_completed = lambda: len(self._dag.tasks) == (len(
self._completed_tasks) + len(self._failed_tasks) + len(
self._block_tasks))
# TODO(andy):Serve has a logic to prevent from too many services running
# at the same time. We should have a similar logic here, but instead we
# should calculate the sum of the subtasks (an upper bound), instead of
# the number of jobs (dags).
# Further, we could try to calculate the maximum concurrency in the dag
# (e.g. for a chain dag it is 1 instead of n), which could allow us to
# run more dags in parallel.
max_workers = len(self._dag.tasks)
with futures.ThreadPoolExecutor(max_workers) as executor:
future_to_task = {}
while not all_tasks_completed():
while not self._task_queue.empty():
task_id = self._task_queue.get()
logger.info(f'Submitting task {task_id} to executor.')
log_file_name = os.path.join(
constants.SKY_LOGS_DIRECTORY, 'managed_jobs',
f'replica_{task_id}_launch.log')
log_file_name = os.path.join(constants.SKY_LOGS_DIRECTORY,
'managed_jobs',
f'task_{task_id}_launch.log')
with ux_utils.RedirectOutputForThread() as redirector:
future = executor.submit(
redirector.wrap(self._run_one_task, log_file_name),
redirector.run(self._run_one_task, log_file_name),
task_id, self._dag.tasks[task_id])
future_to_task[future] = task_id

done, _ = futures.wait(future_to_task.keys(),
timeout=1,
return_when='FIRST_COMPLETED')

for future in done:
Expand Down
6 changes: 2 additions & 4 deletions sky/jobs/recovery_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import time
import traceback
import typing
from typing import cast, List, Optional
from typing import Optional

import sky
from sky import backends
Expand Down Expand Up @@ -320,9 +320,7 @@ def _launch(self,
# Failing directly avoids the infinite loop of retrying
# the launch when, e.g., an invalid cluster name is used
# and --retry-until-up is specified.
reasons = cast(
List[Exception],
e.failover_history if e.failover_history else [e])
reasons = e.failover_history if e.failover_history else [e]
reasons_str = '; '.join(
common_utils.format_exception(err) for err in reasons)
logger.error(
Expand Down
1 change: 0 additions & 1 deletion sky/serve/serve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
if typing.TYPE_CHECKING:
import fastapi

from sky import task as task_lib
from sky.serve import replica_managers

SKY_SERVE_CONTROLLER_NAME: str = (
Expand Down
5 changes: 1 addition & 4 deletions sky/skylet/log_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,8 @@ def run_bash_command_with_log(bash_command: str,
# Need this `-i` option to make sure `source ~/.bashrc` work.
inner_command = f'/bin/bash -i {script_path}'

subprocess_cmd: Union[str, List[str]]
subprocess_cmd = inner_command

return run_with_log(
subprocess_cmd,
inner_command,
log_path,
stream_logs=stream_logs,
with_ray=with_ray,
Expand Down
2 changes: 1 addition & 1 deletion sky/utils/ux_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
sys.stdout, sys.stderr = (self._thread_aware_stdout,
self._thread_aware_stderr)

def wrap(self, func: Callable, filepath: str, mode: str = 'w'):
def run(self, func: Callable, filepath: str, mode: str = 'w'):
"""Wraps a function to redirect its output to a specific file."""

def wrapped(*args, **kwargs) -> Any:
Expand Down

0 comments on commit d107a73

Please sign in to comment.