diff --git a/src/tiledb/cloud/dag/dag.py b/src/tiledb/cloud/dag/dag.py index 1ac16f831..67573917a 100644 --- a/src/tiledb/cloud/dag/dag.py +++ b/src/tiledb/cloud/dag/dag.py @@ -14,6 +14,7 @@ Deque, Dict, FrozenSet, + Hashable, List, Optional, Sequence, @@ -92,6 +93,29 @@ def __init__(self, cause: BaseException, node: "Node"): class Node(futures.FutureLike[_T]): + """Representation of a function to run in a DAG. + + :param func: Function to run as UDF task. + :param *args: Positional arguments to pass to UDF. + :param name: Human-readable name of Node task. + :param dag: DAG this node is associated with. + :param mode: Mode the Node is to run in. + :param expand_node_output: Node to expand processes upon. + :param _download_results: An optional boolean to override default + result-downloading behavior. If True, will always download the + results of the function immediately upon completion. + If False, will not download the results of the function immediately, + but will be downloaded when ``.result()`` is called. + :param _internal_prewrapped_func: For internal use only. A function that returns. + something that is already a Result, which does not require wrapping. + We assume that all prewrapped functions make server calls. + :param _internal_accepts_stored_params: For internal use only. + Applies only when ``_prewrapped_func`` is used. + ``True`` if ``_prewrapped_func`` can accept stored parameters. + ``False`` if it cannot, and all parameters must be serialized. + :param **kwargs: Keyword arguments to pass to UDF. + """ + def __init__( self, func: Callable[..., _T], @@ -103,32 +127,12 @@ def __init__( _download_results: Optional[bool] = None, _internal_prewrapped_func: Callable[..., "results.Result[_T]"] = None, _internal_accepts_stored_params: bool = True, - **kwargs, - ): - """ - Node is a class that represents a function to run in a DAG - :param func: function to run - :param args: tuple of arguments to run - :param name: optional name of dag - :param dag: dag this node is associated with - :param mode: Mode the node is to run in. - :param _download_results: An optional boolean to override default - result-downloading behavior. If True, will always download the - results of the function immediately upon completion. - If False, will not download the results of the function immediately, - but will be downloaded when ``.result()`` is called. - :param _prewrapped_func: For internal use only. A function that returns - something that is already a Result, which does not require wrapping. - We assume that all prewrapped functions make server calls. - :param _accepts_stored_params: For internal use only. - Applies only when ``_prewrapped_func`` is used. - ``True`` if ``_prewrapped_func`` can accept stored parameters. - ``False`` if it cannot, and all parameters must be serialized. - :param kwargs: dictionary for keyword arguments - """ + **kwargs: Any, + ) -> None: self.id = uuid.uuid4() + """UUID for Node instance.""" self._name = name - + """Name of Node instances.""" self._lifecycle_condition = threading.Condition(threading.Lock()) self._status = Status.NOT_STARTED self._starting = False @@ -139,7 +143,9 @@ def __init__( self._cb_list: List[Callable[["Node[_T]"], None]] = [] self.dag = dag + """DAG this Node is pinned to.""" self.mode: Mode = mode + """Processing mode of Node.""" self._expand_node_output: Optional[Node] = expand_node_output self._resource_class = kwargs.pop("resource_class", None) @@ -169,34 +175,38 @@ def __init__( self._download_results = _download_results self.parents: Dict[uuid.UUID, Node] = {} + """Parent Nodes (Nodes this Node is dependent on).""" self.children: Dict[uuid.UUID, Node] = {} - + """Child Nodes (Nodes dependent on this Node).""" self._has_node_args = False self.args: Tuple[Any, ...] = args + """Positional args to pass into UDF.""" self.kwargs: Dict[str, Any] = kwargs + """Keyword args to pass into UDF.""" self._check_resources_and_mode() self._find_deps() - def __hash__(self): + def __hash__(self) -> Hashable: return hash(self.id) - def __eq__(self, other): + def __eq__(self, other) -> bool: return type(self) == type(other) and self.id == other.id - def __ne__(self, other): + def __ne__(self, other) -> bool: return not (self == other) @property def name(self) -> str: + """The human-readable name of Node.""" return self._name or str(self.id) @name.setter def name(self, to: Optional[str]) -> None: self._name = to - def _check_resources_and_mode(self): - """ - Check if the user has set the resource options correctly for the mode + def _check_resources_and_mode(self) -> None: + """Check if the user has set the resource options correctly + for the mode. """ resources_set = self._resources is not None @@ -220,8 +230,9 @@ def _check_resources_and_mode(self): "Resource class cannot be set for locally-executed nodes." ) - def _find_deps(self): + def _find_deps(self) -> None: """Finds Nodes this depends on and adds them to our dependency list.""" + parents = _find_parent_nodes((self.args, self.kwargs)) for dep in parents: self._has_node_args = True @@ -241,13 +252,13 @@ def _find_deps(self): " in a subsequent DAG if they are already complete." ) from e - def depends_on(self, node: "Node"): - """ - Create dependency chain for node, useful when there is a dependency - that does not rely directly on passing results from one to another + def depends_on(self, node: "Node") -> None: + """Create dependency chain for node, useful when there is a dependency + that does not rely directly on passing results from one to another. + :param node: node to mark as a dependency of this node - :return: """ + with self._lifecycle_condition: if self._status is not Status.NOT_STARTED: raise RuntimeError("Cannot add dependency to an already-started node.") @@ -260,6 +271,8 @@ def depends_on(self, node: "Node"): # FutureLike methods. def cancel(self) -> bool: + """Cancel Node.""" + with self._lifecycle_condition: if self._status is not Status.NOT_STARTED: return False @@ -274,14 +287,24 @@ def done(self) -> bool: return self._done() def cancelled(self) -> bool: + """Whether Node is cancelled.""" + with self._lifecycle_condition: return self._status is Status.CANCELLED def running(self) -> bool: + """Whether Node is actively running.""" + with self._lifecycle_condition: return self._status is Status.RUNNING def result(self, timeout: Optional[float] = None) -> _T: + """Fetch Node return. + + :param timeout: Time to wait to fetch result. + :return: Results of Node processing. + """ + if self.mode == Mode.BATCH: with self._lifecycle_condition: self._wait(timeout) @@ -303,6 +326,8 @@ def result(self, timeout: Optional[float] = None) -> _T: return result.get() def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]: + """Return execption if one was raised.""" + with self._lifecycle_condition: self._wait(timeout) if self._lifecycle_exception: @@ -310,6 +335,10 @@ def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]: return self._exception def add_done_callback(self, fn: Callable[["Node[_T]"], None]) -> None: + """Add callback function to execute at Node completion. + + :param fn: Callaback to execute.""" + with self._lifecycle_condition: self._cb_list.append(fn) if not self._done(): @@ -330,14 +359,20 @@ def future(self) -> futures.FutureLike[_T]: @property def status(self) -> st.Status: + """Node status.""" + with self._lifecycle_condition: return self._status @property def error(self) -> Optional[Exception]: + """Return Node error if encountered.""" + return self._error() def retry(self) -> bool: + """Retry Node.""" + if not self.dag: return False with self._lifecycle_condition: @@ -349,8 +384,9 @@ def retry(self) -> bool: def task_id(self) -> Optional[uuid.UUID]: """Gets the server-side Task ID of this node. - Returns None if this has no task ID (as it was run on the client side). + :return: None if this has no task ID (as it was run on the client side). """ + try: with self._lifecycle_condition: if not self._result: @@ -362,12 +398,13 @@ def task_id(self) -> Optional[uuid.UUID]: return None return sp.task_id - def wait(self, timeout: Optional[float] = None): - """ - Wait for node to be completed - :param timeout: optional timeout in seconds to wait for DAG to be completed - :return: None or raises TimeoutError if timeout occurs + def wait(self, timeout: Optional[float] = None) -> None: + """Wait for node to be completed. + + :param timeout: optional timeout in seconds to wait for DAG to be completed. + :return: None or raises TimeoutError if timeout occurs. """ + with self._lifecycle_condition: self._wait(timeout) @@ -390,7 +427,7 @@ def _reset_internal(self) -> bool: def _callbacks(self): return tuple(self._cb_list) - def _wait(self, timeout: Optional[float]): + def _wait(self, timeout: Optional[float]) -> None: futures.wait_for(self._lifecycle_condition, self._done, timeout) def _done(self) -> bool: @@ -576,6 +613,24 @@ def _to_log_metadata(self) -> rest_api.TaskGraphNodeMetadata: class DAG: + """Low-level API for creating and managing direct acyclic graphs + as TileDB Cloud Task Graphs. + + :param max_workers: Number of works to allocate to execute DAG. + :param use_processes: If true will use processes instead of threads. + :param done_callback: Optional call back function to register for + when dag is completed. Function will be passed reference to this DAG. + :param update_callback: Optional call back function to register for + when dag status is updated. Function will be passed reference to this DAG. + :param namespace: Namespace to execute DAG in. + :param name: Human-readable name for DAG to be showin in Task Graph logs. + :param mode: Mode the DAG is to run in, valid options are: Mode.REALTIME, + Mode.BATCH. + :param retry_strategy: K8S retry policy to be applied to each Node. + :param workflow_retry_strategy: K8S retry policy to be applied to DAG. + :param deadline: Duration (sec) DAG allowed to execute before timeout. + """ + def __init__( self, max_workers: Optional[int] = None, @@ -588,52 +643,40 @@ def __init__( retry_strategy: Optional[models.RetryStrategy] = None, workflow_retry_strategy: Optional[models.RetryStrategy] = None, deadline: Optional[int] = None, - ): - """ - DAG is a class for creating and managing direct acyclic graphs - :param max_workers: how many workers should be used to execute the dag - :param use_processes: if true will use processes instead of threads, - defaults to threads - :param done_callback: optional call back function to register for - when dag is completed. Function will be passed reference to this dag - :param update_callback: optional call back function to register for - when dag status is updated. Function will be passed reference to this dag - :param namespace: optional namespace to use for all tasks in DAG - :param name: A human-readable name used to identify this task graph - in logs. Does not need to be unique. - :param mode: Mode the DAG is to run in, valid options are - Mode.REALTIME, Mode.BATCH - :param retry_strategy: RetryStrategy to be applied on every node of the DAG. - :param workflow_retry_strategy: RetryStrategy to use to retry the entire DAG. - :param deadline: Duration in seconds relative to the workflow start time - which the workflow is allowed to run before it gets terminated. - """ - self.id = uuid.uuid4() + ) -> None: + self.id: uuid.UUID = uuid.uuid4() + """UUID for DAG instance.""" self.nodes: Dict[uuid.UUID, Node] = {} + """Mapping of Node UUIDs to Node instances.""" self.nodes_by_name: Dict[str, Node] = {} - - self.namespace = namespace or client.default_charged_namespace( + """Mapping of Node names to Node instances.""" + self.namespace: str = namespace or client.default_charged_namespace( required_action=rest_api.NamespaceActions.RUN_JOB ) - self.name = name + """Namespace to execute DAG in.""" + self.name: Optional[str] = name + """Human-readable name for DAG to be showin in Task Graph logs.""" self.server_graph_uuid: Optional[uuid.UUID] = None - self.max_workers = max_workers - self.retry_strategy = retry_strategy - self.workflow_retry_strategy = workflow_retry_strategy - self.deadline = deadline - - self._update_batch_status_thread: Optional[threading.Thread] = None - """The thread that is updating the status of Batch execution.""" - self.mode: Mode = mode - """The server-generated UUID of this graph, used for logging. - Will be ``None`` until :meth:`initial_setup` is called. If submitting the log works, will be the UUID; otherwise, will be None. """ - + self.max_workers: Optional[int] = max_workers + """Number of works to allocate to execute DAG.""" + self.retry_strategy: Optional[models.RetryStrategy] = retry_strategy + """K8S retry policy to be applied to each Node.""" + self.workflow_retry_strategy: Optional[ + models.RetryStrategy + ] = workflow_retry_strategy + """K8S retry policy to be applied to DAG.""" + self.deadline: Optional[str] = deadline + """Duration (sec) DAG allowed to execute before timeout.""" + self._update_batch_status_thread: Optional[threading.Thread] = None + """The thread that is updating the status of Batch execution.""" + self.mode: Mode = mode + """Mode the DAG is to run in.""" self.visualization = None - + """Visualization metadata.""" self._udf_executor: futures.Executor """The executor that is used to make server calls and run local UDFs.""" if use_processes: @@ -648,15 +691,17 @@ def __init__( max_workers=max_workers, ) """The thread pool that is used to execute nodes' exec functions.""" - self._lifecycle_condition = threading.Condition(threading.Lock()) - self.completed_nodes: Dict[uuid.UUID, Node] = {} + """Completed Nodes.""" self.failed_nodes: Dict[uuid.UUID, Node] = {} + """Failed Nodes.""" self.running_nodes: Dict[uuid.UUID, Node] = {} + """Running Nodes.""" self.not_started_nodes: Dict[uuid.UUID, Node] = {} + """Queued Nodes.""" self.cancelled_nodes: Dict[uuid.UUID, Node] = {} - + """Cancelled Nodes.""" self._status = st.Status.NOT_STARTED self._done_callbacks = [] @@ -669,25 +714,28 @@ def __init__( self._tried_setup: bool = False - def __hash__(self): + def __hash__(self) -> Hashable: return hash(self.id) - def __eq__(self, other): + def __eq__(self, other) -> bool: return type(self) == type(other) and self.id == other.id - def __ne__(self, other): + def __ne__(self, other) -> bool: return not (self == other) @property def status(self): + """Get DAG status.""" + with self._lifecycle_condition: return self._status - def initial_setup(self): + def initial_setup(self) -> uuid.UUID: """Performs one-time server-side setup tasks. Can safely be called multiple times. """ + with self._lifecycle_condition: if not self._tried_setup: log_structure = self._build_log_structure() @@ -721,7 +769,6 @@ def add_update_callback(self, func): Add a callback for when DAG status is updated :param func: Function to call when DAG status is updated. The function will be passed reference to this dag - :return: """ if not callable(func): raise TypeError("func to add_update_callback must be callable") @@ -734,7 +781,6 @@ def add_done_callback(self, func): Add a callback for when DAG is completed :param func: Function to call when DAG status is updated. The function will be passed reference to this dag - :return: """ if not callable(func): raise TypeError("func to add_done_callback must be callable") @@ -793,17 +839,19 @@ def done(self) -> bool: with self._lifecycle_condition: return self._done() - def add_node_obj(self, node): - """ - Add node to DAG + def add_node_obj(self, node) -> Node: + """Add node to DAG. + :param node: to add to dag - :return: node + :return: Node instance. """ + with self._lifecycle_condition: return self._add_node_internal(node) def _add_node_internal(self, node: Node) -> Node: """Add node implementation. Must hold lifecycle condition.""" + if self._status is not Status.NOT_STARTED: raise RuntimeError("Cannot add nodes to a running graph") self.nodes[node.id] = node @@ -815,8 +863,7 @@ def _add_node_internal(self, node: Node) -> Node: return node def add_node(self, func_exec, *args, name=None, local_mode=True, **kwargs): - """ - Create and add a node. + """Create and add a node. DEPRECATED. Use `submit_local` instead. @@ -825,6 +872,7 @@ def add_node(self, func_exec, *args, name=None, local_mode=True, **kwargs): :param name: name :return: Node that is created """ + mode = Mode.LOCAL if local_mode else Mode.REALTIME return self._add_raw_node( @@ -904,14 +952,15 @@ def _add_prewrapped_node( ) return self._add_node_internal(node) - def submit_array_udf(self, func, *args, **kwargs): - """ - Submit a function that will be executed in the cloud serverlessly - :param func: function to execute - :param args: arguments for function execution - :param name: name - :return: Node that is created + def submit_array_udf(self, func: Callable, *args: Any, **kwargs: Any): + """Submit a function that will be executed in the cloud serverlessly. + + :param func: Function to execute in UDF task. + :param *args: Postional arguments to pass into Node instantation. + :param **kwargs: Keyword args to pass into Node instantiation. + :return: Node that is created. """ + return self._add_prewrapped_node( array.apply_base, func, @@ -920,24 +969,25 @@ def submit_array_udf(self, func, *args, **kwargs): **kwargs, ) - def submit_local(self, func, *args, **kwargs): - """ - Submit a function that will run locally - :param func: function to execute - :param args: arguments for function execution - :param name: name + def submit_local(self, func: Callable, *args: Any, **kwargs): + """Submit a function that will run locally. + + :param func: Function to execute in UDF task. + :param *args: Postional arguments to pass into Node instantation. + :param **kwargs: Keyword args to pass into Node instantiation. :return: Node that is created """ + kwargs.setdefault("name", functions.full_name(func)) return self._add_raw_node(func, *args, mode=Mode.LOCAL, **kwargs) - def submit_udf(self, func, *args, **kwargs): - """ - Submit a function that will be executed in the cloud serverlessly - :param func: function to execute - :param args: arguments for function execution - :param name: name - :return: Node that is created + def submit_udf(self, func: Callable, *args, **kwargs): + """Submit a function that will be executed in the cloud serverlessly. + + :param func: Function to execute in UDF task. + :param *args: Postional arguments to pass into Node instantation. + :param **kwargs: Keyword args to pass into Node instantiation. + :return: Node that is created. """ if "local_mode" in kwargs: @@ -960,16 +1010,38 @@ def submit_udf(self, func, *args, **kwargs): submit = submit_udf def submit_udf_stage( - self, func, *args, expand_node_output: Optional[Node] = None, **kwargs - ): - """ - Submit a function that will be executed in the cloud serverlessly - :param func: function to execute - :param args: arguments for function execution - :param expand_node_output: the Node that we want to expand the output of. + self, + func: Callable, + *args: Any, + expand_node_output: Optional[Node] = None, + **kwargs: Any, + ) -> Node: + """Submit a function that will be executed in the cloud serverlessly. + + Expand on node output simply means to dynamically allocate works to this UDF + stage based on the output of the node indicated via the `expand_node_output` + arg. + + For example, if a node, `NodeA` (`NodeA = DAG.submit(...)`), returns a list + of str values and `NodeA` is passed to `expand_node_output`, along with an + arg in the `func` passed to `submit_udf_stage` that accepts a str is also + passed `NodeA`, a node will spawn in parallel for each str value in the + result of `NodeA`. + + ```python + graph = DAG(...) + + NodeA = graph.submit() + + NodeB = graph.submit_udf_stage(..., expand_node_output=NodeA, str_arg=NodeA) + ``` + + :param func: Function to execute in UDF task. + :param *args: Postional arguments to pass into Node instantation. + :param expand_node_output: Node that we want to expand the output of. The output of the node should be a JSON encoded list. - :param name: name - :return: Node that is created + :param **kwargs: Keyword args to pass into Node instantiation. + :return: Node that is created. """ if "local_mode" in kwargs or self.mode != Mode.BATCH: @@ -986,14 +1058,15 @@ def submit_udf_stage( **kwargs, ) - def submit_sql(self, *args, **kwargs): - """ - Submit a sql query to run serverlessly in the cloud - :param sql: query to execute - :param args: arguments for function execution - :param name: name + def submit_sql(self, *args: Any, **kwargs: Any) -> Node: + """Submit a sql query to run serverlessly in the cloud. + + :param sql: Query to execute. + :param *args: Postional arguments to pass into Node instantation. + :param **kwargs: Keyword args to pass into Node instantiation. :return: Node that is created """ + return self._add_prewrapped_node( _sql_exec.exec_base, *args, @@ -1024,12 +1097,12 @@ def report_node_status_change(self, node: Node, new_status: Status): elif new_status is Status.CANCELLED: self.cancelled_nodes[node.id] = node - def report_node_complete(self, node: Node): - """ - Report a node as complete - :param node: to mark as complete - :return + def report_node_complete(self, node: Node) -> None: + """Report a node as complete. + + :param node: Node to mark as complete. """ + to_report: Optional[str] = None """The client-side event to report to the server, if needed. @@ -1039,6 +1112,7 @@ def report_node_complete(self, node: Node): the graph information on the server can know about these failed nodes even though there will be no ArrayTask for them in our database. """ + with self._lifecycle_condition: # A node may be either "running" or "not started" depending upon # whether it failed or was cancelled. @@ -1127,11 +1201,12 @@ def _set_status(self, st: Status) -> None: self._status = st self._lifecycle_condition.notify_all() - def _find_root_nodes(self): - """ - Find all root nodes - :return: list of root nodes + def _find_root_nodes(self) -> List[Node]: + """Find all root nodes. + + :return: list of root nodes. """ + roots = [] for node in self.nodes.values(): if node.parents is None or len(node.parents) == 0: @@ -1139,18 +1214,17 @@ def _find_root_nodes(self): return roots - def _find_leaf_nodes(self): - """ - Find all leaf nodes + def _find_leaf_nodes(self) -> List[Node]: + """Find all leaf nodes. + :return: list of leaf nodes """ + return [n for n in self.nodes.values() if not n.children] - def compute(self): - """ - Start the DAG by executing root nodes - :return: - """ + def compute(self) -> None: + """Start the DAG by executing root nodes.""" + with self._lifecycle_condition: if self._status is not Status.NOT_STARTED: return @@ -1191,8 +1265,8 @@ def _maybe_exec(self, node: Node): del self.not_started_nodes[node.id] def wait(self, timeout: Optional[float] = None) -> None: - """ - Wait for DAG to be completed + """Wait for DAG to be completed. + :param timeout: optional timeout in seconds to wait for DAG to be completed :return: None or raises TimeoutError if timeout occurs """ @@ -1211,7 +1285,9 @@ def wait(self, timeout: Optional[float] = None) -> None: assert exc raise exc - def cancel(self): + def cancel(self) -> None: + """Cancel DAG.""" + if self.mode == Mode.BATCH: client.build(rest_api.TaskGraphLogsApi).stop_task_graph_execution( namespace=self.namespace, id=self.server_graph_uuid @@ -1230,6 +1306,7 @@ def cancel(self): def retry_all(self) -> None: """Retries all failed and cancelled nodes.""" + if self.mode == Mode.BATCH: for node in frozenset(self.failed_nodes.values()).union( self.cancelled_nodes.values() @@ -1263,9 +1340,9 @@ def retry_all(self) -> None: for n in to_retry: n.retry() - def find_end_nodes(self): - """ - Find all end nodes + def find_end_nodes(self) -> List[Node]: + """Find all end nodes. + :return: list of end nodes """ end = [] @@ -1275,7 +1352,12 @@ def find_end_nodes(self): return end - def stats(self): + def stats(self) -> Dict[str, Union[int, float]]: + """Get DAG node statistics. + + :return: All node stats. + """ + return { "percent_complete": len(self.completed_nodes) / len(self.nodes) * 100, "running": len(self.running_nodes), @@ -1299,11 +1381,12 @@ def networkx_graph(self): return graph - def get_tiledb_plot_node_details(self): - """ - Build list of details needed for tiledb node graph - :return: + def get_tiledb_plot_node_details(self) -> Dict[str, Dict[str, str]]: + """Build list of details needed for tiledb node graph + + :return: Node summary """ + node_details = {} for node in self.nodes.values(): @@ -1338,14 +1421,14 @@ def _update_dag_plotly_graph(graph): ) def visualize(self, notebook=True, auto_update=True, force_plotly=False): - """ - Build and render a tree diagram of the DAG. + """Build and render a tree diagram of the DAG. + :param notebook: Is the visualization inside a jupyter notebook? - If so we'll use a widget - :param auto_update: Should the diagram be auto updated with each status change + If so we'll use a widget. + :param auto_update: Should the diagram be auto updated with each status change. :param force_plotly: Force the use of plotly graphs instead of - TileDB Plot Widget - :return: returns figure + TileDB Plot Widget. + :return: returns figure. """ if not notebook or force_plotly: return self._visualize_plotly(notebook=notebook, auto_update=auto_update) @@ -1385,7 +1468,7 @@ def _visualize_tiledb(self, auto_update=True): return fig def _visualize_plotly(self, notebook=True, auto_update=True): - """ + """Visualize figure. :param notebook: Is the visualization inside a jupyter notebook? If so we'll use a widget @@ -1471,7 +1554,7 @@ def _visualize_plotly(self, notebook=True, auto_update=True): def end_nodes(self): """ - Find all ends nodes + Find all end nodes dag = DAG() dag.add_node(Node()) @@ -1540,7 +1623,7 @@ def _build_batch_taskgraph(self): func = node_args.pop(0) kwargs["executable_code"] = codecs.PickleCodec.encode_base64(func) kwargs["source_text"] = functions.getsourcelines(func) - if type(node.args[0]) == str: + if isinstance(node.args[0], str): func = node_args.pop(0) kwargs["registered_udf_name"] = func