Skip to content

Commit

Permalink
also kill other tpu jobs in case we get non-RayTaskErrors
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Oct 23, 2024
1 parent e6033d1 commit b2636d4
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/levanter/infra/ray_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def run_on_pod(remote_fn: RemoteFunction | Callable, tpu_type: str) -> ray.Objec

@ray.remote(resources={f"TPU-{tpu_type}-head": 1})
def do_run(remote_fn) -> _TpuRunResult:
logging.basicConfig(level=logging.INFO)
num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> 4
remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, num_hosts)

Expand All @@ -92,6 +93,13 @@ def do_run(remote_fn) -> _TpuRunResult:
except Exception:
logger.exception("Failed to kill job after primary failure")
return _handle_ray_error(info, e)
except Exception as e:
for f in futures:
try:
ray.cancel(f)
except Exception:
logger.exception("Failed to kill job after primary failure")
return TpuFailed(info, e)

return do_run.remote(remote_fn)

Expand Down Expand Up @@ -144,12 +152,12 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re
out = ray.get(run_on_pod(remote_fn, tpu_type))
except ray.exceptions.RayTaskError as e:
problem = e
if "preempted" in str(e):
if "preempted" in str(e).lower():
num_preemptions += 1
logger.warning(f"Preempted {num_preemptions} times, {e}")
else:
num_failures += 1
logger.warning(f"Failed {num_failures} times")
logger.warning(f"Failed {num_failures} times", exc_info=e)
continue
except Exception as e:
problem = e
Expand Down

0 comments on commit b2636d4

Please sign in to comment.