diff --git a/python/ray/dashboard/utils.py b/python/ray/dashboard/utils.py index 3efd441039bc5..33388b2885553 100644 --- a/python/ray/dashboard/utils.py +++ b/python/ray/dashboard/utils.py @@ -206,32 +206,23 @@ def to_google_style(d): def message_to_dict(message, decode_keys=None, **kwargs): """Convert protobuf message to Python dict.""" - def _decode_keys(d): - for k, v in d.items(): - if isinstance(v, dict): - d[k] = _decode_keys(v) - if isinstance(v, list): - new_list = [] - for i in v: - if isinstance(i, dict): - new_list.append(_decode_keys(i)) - else: - new_list.append(i) - d[k] = new_list - else: - if k in decode_keys: - d[k] = binary_to_hex(b64decode(v)) - else: - d[k] = v - return d - d = ray._private.protobuf_compat.message_to_dict( message, use_integers_for_enums=False, **kwargs ) - if decode_keys: - return _decode_keys(d) - else: - return d + + def _decode_rec(o, should_decode=False): + if isinstance(o, dict): + for k, v in o.items(): + o[k] = _decode_rec(v, should_decode=k in decode_keys) + return o + elif isinstance(o, list): + return [_decode_rec(i, should_decode) for i in o] + elif should_decode: + return binary_to_hex(b64decode(o)) + else: + return o + + return _decode_rec(d) if decode_keys else d class SignalManager: diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index 908faefcd9fd1..2fef31cb4e1ee 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -23,7 +23,7 @@ import ray.dashboard.consts as dashboard_consts import ray._private.state as global_state import ray._private.ray_constants as ray_constants -from ray._raylet import ActorID +from ray._raylet import ActorID, ObjectRef from ray._private.test_utils import ( run_string_as_driver, wait_for_condition, @@ -1015,6 +1015,7 @@ async def test_api_manager_list_tasks_events(state_api_manager): data_source_client.get_all_task_info = AsyncMock() id = b"1234" func_or_class = "f" + arg_ref = ObjectRef.from_random() # Generate a task event. @@ -1023,8 +1024,10 @@ async def test_api_manager_list_tasks_events(state_api_manager): name=func_or_class, func_or_class_name=func_or_class, type=TaskType.NORMAL_TASK, + args_object_ids=[arg_ref.binary()], ) - current = time.time_ns() + + current = 0 second = int(1e9) state_updates = TaskStateUpdate( node_id=node_id.binary(), @@ -1048,31 +1051,42 @@ async def test_api_manager_list_tasks_events(state_api_manager): ) data_source_client.get_all_task_info.side_effect = [generate_task_data([events])] result = await state_api_manager.list_tasks(option=create_api_options(detail=True)) - result = result.result[0] - assert "events" in result - assert result["state"] == "FINISHED" - expected_events = [ - { - "state": "PENDING_ARGS_AVAIL", - "created_ms": current // 1e6, - }, - { - "state": "SUBMITTED_TO_WORKER", - "created_ms": (current + second) // 1e6, - }, - { - "state": "RUNNING", - "created_ms": (current + 2 * second) // 1e6, - }, - { - "state": "FINISHED", - "created_ms": (current + 3 * second) // 1e6, - }, - ] - for actual, expected in zip(result["events"], expected_events): - assert actual == expected - assert result["start_time_ms"] == (current + 2 * second) // 1e6 - assert result["end_time_ms"] == (current + 3 * second) // 1e6 + + assert { + "state": "FINISHED", + "type": "NORMAL_TASK", + "name": "f", + "error_message": None, + "events": [ + {"state": "PENDING_ARGS_AVAIL", "created_ms": 0.0}, + {"state": "SUBMITTED_TO_WORKER", "created_ms": 1000.0}, + {"state": "RUNNING", "created_ms": 2000.0}, + {"state": "FINISHED", "created_ms": 3000.0}, + ], + "runtime_env_info": None, + "end_time_ms": 3000.0, + "job_id": "30303031", + "error_type": None, + "func_or_class_name": "f", + "attempt_number": 0, + "node_id": node_id.hex(), + "required_resources": {}, + "worker_pid": None, + "language": "PYTHON", + "placement_group_id": None, + "creation_time_ms": 0.0, + "worker_id": None, + "task_log_info": None, + "profiling_data": {}, + "actor_id": None, + "is_debugger_paused": None, + "start_time_ms": 2000.0, + "task_id": "31323334", + "parent_task_id": "", + "args_object_ids": [ + arg_ref.hex(), + ], + } == result.result[0] """ Test only start_time_ms is updated. diff --git a/python/ray/util/state/common.py b/python/ray/util/state/common.py index 686d5355af4a0..93dbb2fda208e 100644 --- a/python/ray/util/state/common.py +++ b/python/ray/util/state/common.py @@ -760,6 +760,9 @@ class TaskState(StateSchema): error_message: Optional[str] = state_column(detail=True, filterable=False) # Is task paused by the debugger is_debugger_paused: Optional[bool] = state_column(detail=True, filterable=True) + # List of objects (passed in by ref as arguments) this task + # is dependent on + args_object_ids: List[str] = state_column(detail=True, filterable=False) @dataclass(init=not IS_PYDANTIC_2) @@ -1549,6 +1552,8 @@ def protobuf_to_task_state_dict(message: TaskEvents) -> dict: "worker_id", "placement_group_id", "component_id", + "object_id", + "args_object_ids", ], ) @@ -1579,6 +1584,7 @@ def protobuf_to_task_state_dict(message: TaskEvents) -> dict: "runtime_env_info", "parent_task_id", "placement_group_id", + "args_object_ids", ], ), (task_attempt, ["task_id", "attempt_number", "job_id"]), diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 61052f231eda1..4edc66a0a3f85 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -637,6 +637,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ [this] { RAY_LOG(INFO) << "Event stats:\n\n" << io_service_.stats().StatsString() << "\n\n" + << task_execution_service_.stats().StatsString() << "\n\n" << "-----------------\n" << "Task Event stats:\n" << task_event_buffer_->DebugString() << "\n"; diff --git a/src/ray/gcs/pb_util.h b/src/ray/gcs/pb_util.h index cb3c518072b2e..f3a792641a751 100644 --- a/src/ray/gcs/pb_util.h +++ b/src/ray/gcs/pb_util.h @@ -252,9 +252,17 @@ inline void FillTaskInfo(rpc::TaskInfoEntry *task_info, resources_map.end()); task_info->mutable_runtime_env_info()->CopyFrom(task_spec.RuntimeEnvInfo()); const auto &pg_id = task_spec.PlacementGroupBundleId().first; + if (!pg_id.IsNil()) { task_info->set_placement_group_id(pg_id.Binary()); } + + // Fill in task args + for (size_t i = 0; i < task_spec.NumArgs(); i++) { + if (task_spec.ArgByRef(i)) { + task_info->add_args_object_ids(task_spec.ArgRef(i).object_id()); + } + } } // Fill task_info for the export API with task specification from task_spec diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index fb20006d57b96..2dddf10d01c31 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -589,6 +589,13 @@ message TaskInfoEntry { // If the task/actor is created within a placement group, // this value is configured. optional bytes placement_group_id = 26; + // Tasks arguments passed in as (object) references. + // + // NOTE: This list only contains `ObjectReference`s passed in as arguments + // this task is dependent on and does NOT contain + // - Args passed by value (inlined) + // - ObjectRefs of the args passed by value + repeated bytes args_object_ids = 27; } message Bundle {