diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index 3ae5d0105..2dc554808 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -76,6 +76,7 @@ def run_on_pod(remote_fn: RemoteFunction | Callable, tpu_type: str) -> ray.Objec @ray.remote(resources={f"TPU-{tpu_type}-head": 1}) def do_run(remote_fn) -> _TpuRunResult: + logging.basicConfig(level=logging.INFO) num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> 4 remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, num_hosts) @@ -92,6 +93,13 @@ def do_run(remote_fn) -> _TpuRunResult: except Exception: logger.exception("Failed to kill job after primary failure") return _handle_ray_error(info, e) + except Exception as e: + for f in futures: + try: + ray.cancel(f) + except Exception: + logger.exception("Failed to kill job after primary failure") + return TpuFailed(info, e) return do_run.remote(remote_fn) @@ -144,12 +152,12 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re out = ray.get(run_on_pod(remote_fn, tpu_type)) except ray.exceptions.RayTaskError as e: problem = e - if "preempted" in str(e): + if "preempted" in str(e).lower(): num_preemptions += 1 logger.warning(f"Preempted {num_preemptions} times, {e}") else: num_failures += 1 - logger.warning(f"Failed {num_failures} times") + logger.warning(f"Failed {num_failures} times", exc_info=e) continue except Exception as e: problem = e