Skip to content

Commit

Permalink
Fix graph generation when positional arguments are used (#20)
Browse files Browse the repository at this point in the history
* Fix graph generation when positional arguments are used

Signed-off-by: 1597463007 <[email protected]>

* Bump minor version

Signed-off-by: 1597463007 <[email protected]>

* Rename task graph to dict graph

Signed-off-by: 1597463007 <[email protected]>

---------

Signed-off-by: 1597463007 <[email protected]>
Co-authored-by: sharpener6 <[email protected]>
  • Loading branch information
1597463007 and sharpener6 authored Nov 19, 2024
1 parent 17ee3df commit fa65c0c
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 78 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,14 @@ map_reduce_sort_recursive.to_graph(partition_counts=4).to_dot().write_png("map_r

![Map-Reduce Sort Recursive](docs/_static/map_reduce_sort_recursive.png)

Use the `to_dask` method to convert the generated graph to a Dask task graph.
Use the `to_dict` method to convert the generated graph to a dict graph.

```python
import numpy as np
from distributed import Client

with Client() as client:
client.get(map_reduce_sort.to_graph(partition_count=4).to_dask(array=np.random.rand(20)))[0]
client.get(map_reduce_sort.to_graph(partition_count=4).to_dict(array=np.random.rand(20)))[0]

# [0.06253707 0.06795382 0.11492823 0.14512393 0.20183152 0.41109117
# 0.42613798 0.45156214 0.4714821 0.54000373 0.54902451 0.62671881
Expand Down
2 changes: 1 addition & 1 deletion pargraph/about.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.8.4"
__version__ = "0.9.0"
28 changes: 14 additions & 14 deletions pargraph/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ def set_parallel_backend(self, backend: Backend) -> None:
"""
self.backend = backend

def get(self, dsk: Dict, keys: Any, **kwargs) -> Any:
def get(self, graph: Dict, keys: Any, **kwargs) -> Any:
"""
Compute task graph
Compute dict graph
:param dsk: dask-compatible task graph
:param graph: dict graph
:param keys: keys to compute (e.g. ``"x"``, ``["x", "y", "z"]``, etc)
:param kwargs: keyword arguments to forward to the parallel backend
:return: results in the same structure as keys
"""
keyset = set(self._flatten_iter([keys]))

# cull graph to remove any unnecessary dependencies
graphlib_graph = self._cull_graph(self._convert_dsk_to_graph(dsk), keyset)
graphlib_graph = self._cull_graph(self._get_graph_dependencies(graph), keyset)
ref_count_graph = self._create_ref_count_graph(graphlib_graph)

graph = TopologicalSorter(graphlib_graph)
graph.prepare()
topological_sorter = TopologicalSorter(graphlib_graph)
topological_sorter.prepare()

results: Dict[Hashable, Any] = {}
future_to_key: Dict[Future[Any], Hashable] = {}
Expand Down Expand Up @@ -95,14 +95,14 @@ def wait_for_completed_futures():
future_to_key.pop(done_future, None)
done_keys.append(key)

graph.done(*done_keys)
topological_sorter.done(*done_keys)
for done_key in done_keys:
dereference_key(done_key)

# while there are still unscheduled tasks
while graph.is_active():
while topological_sorter.is_active():
# get in vertices
in_keys = graph.get_ready()
in_keys = topological_sorter.get_ready()

# if there are no in-vertices, wait for a future to resolve
# IMPORTANT: we make the assumption that the graph is acyclic
Expand All @@ -111,15 +111,15 @@ def wait_for_completed_futures():
continue

for in_key in in_keys:
computation = dsk[in_key]
computation = graph[in_key]

if self._is_submittable_function_computation(computation):
future = self._submit_function_computation(computation, results, **kwargs)
future_to_key[future] = in_key
else:
result = self._evaluate_computation(computation, results)
results[in_key] = result
graph.done(in_key)
topological_sorter.done(in_key)
dereference_key(in_key)

# resolve all pending futures
Expand Down Expand Up @@ -183,8 +183,8 @@ def _evaluate_computation(cls, computation: Any, results: Dict) -> Optional[Any]
return computation

@staticmethod
def _convert_dsk_to_graph(dsk: Dict) -> Dict:
keys = set(dsk.keys())
def _get_graph_dependencies(graph: Dict) -> Dict:
keys = set(graph.keys())

def flatten(value: Any) -> Set[Any]:
# handle tasks as tuples
Expand All @@ -209,7 +209,7 @@ def flatten(value: Any) -> Set[Any]:

return set()

return {key: flatten(value) for key, value in dsk.items()}
return {key: flatten(value) for key, value in graph.items()}

@staticmethod
def _create_ref_count_graph(graph: Dict) -> Dict:
Expand Down
120 changes: 77 additions & 43 deletions pargraph/graph/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ def __post_init__(self):
assert isinstance(self.value, str), f"Value must be a string; got type '{type(self.value)}'"

@staticmethod
def from_dict(data: Dict) -> "Const":
def from_json(data: Dict) -> "Const":
return Const(**data)

def to_dict(self) -> Dict[str, Any]:
def to_json(self) -> Dict[str, Any]:
return {"type": self.type, "value": self.value}

@staticmethod
Expand Down Expand Up @@ -234,7 +234,7 @@ def __post_init__(self):
), f"Arg '{arg}' must be ConstKey, InputKey, or NodeOutputKey; got type '{type(arg)}'"

@staticmethod
def from_dict(data: Dict) -> "FunctionCall":
def from_json(data: Dict) -> "FunctionCall":
data = data.copy()
function = data.pop("function")
return FunctionCall(
Expand All @@ -247,7 +247,7 @@ def from_dict(data: Dict) -> "FunctionCall":
**data,
)

def to_dict(self) -> Dict[str, Any]:
def to_json(self) -> Dict[str, Any]:
return {
"function": (
base64.b64encode(cloudpickle.dumps(self.function)).decode("ascii")
Expand Down Expand Up @@ -279,16 +279,16 @@ def __post_init__(self):
), f"Arg '{arg}' must be ConstKey, InputKey, or NodeOutputKey; got type '{type(arg)}'"

@staticmethod
def from_dict(data: Dict) -> "GraphCall":
def from_json(data: Dict) -> "GraphCall":
data = data.copy()
return GraphCall(
graph=Graph.from_dict(data.pop("graph")),
graph=Graph.from_json(data.pop("graph")),
args={arg: _get_key_from_str(key_str) for arg, key_str in data.pop("args").items()},
**data,
)

def to_dict(self) -> Dict[str, Any]:
dct: dict = {"graph": self.graph.to_dict(), "args": {arg: key.to_str() for arg, key in self.args.items()}}
def to_json(self) -> Dict[str, Any]:
dct: dict = {"graph": self.graph.to_json(), "args": {arg: key.to_str() for arg, key in self.args.items()}}
if self.graph_name is not None:
dct["graph_name"] = self.graph_name
return dct
Expand Down Expand Up @@ -342,22 +342,22 @@ def __post_init__(self):
), f"Output '{output}' must be type '{ConstKey}', '{InputKey}', or '{NodeOutputKey}'"

@staticmethod
def from_dict(data: Dict) -> "Graph":
def from_json(data: Dict) -> "Graph":
"""
Create graph from graph dict by inferring the graph dict type
Create graph from json serializable dictionary by inferring the graph type
:param data: graph dict
:return: graph
"""
if "edges" in data:
return Graph.from_dict_with_edge_list(data)
return Graph.from_json_with_edge_list(data)

return Graph.from_dict_with_node_arguments(data)
return Graph.from_json_with_node_arguments(data)

@staticmethod
def from_dict_with_edge_list(data: Dict) -> "Graph":
def from_json_with_edge_list(data: Dict) -> "Graph":
"""
Create graph from graph dict with edge list
Create graph from json serializable dictionary with edge list
:param data: graph dict with edge list
:return: graph
Expand Down Expand Up @@ -411,53 +411,53 @@ def from_dict_with_edge_list(data: Dict) -> "Graph":

outputs[key] = new_output

return Graph.from_dict_with_node_arguments(data)
return Graph.from_json_with_node_arguments(data)

@staticmethod
def from_dict_with_node_arguments(data: Dict) -> "Graph":
def from_json_with_node_arguments(data: Dict) -> "Graph":
"""
Create graph from graph dict with node arguments
Create graph from json serializable dictionary with node arguments
:param data: graph dict with node arguments
:return: graph
"""

def _graph_node_from_dict(data: Union[Dict, str]) -> Union[FunctionCall, "GraphCall"]:
def _graph_node_from_json(data: Union[Dict, str]) -> Union[FunctionCall, "GraphCall"]:
if isinstance(data, dict) and "function" in data:
return FunctionCall.from_dict(data)
return FunctionCall.from_json(data)
elif isinstance(data, dict) and "graph" in data:
return GraphCall.from_dict(data)
return GraphCall.from_json(data)

raise ValueError(f"invalid graph node dict '{data}'")

data = data.copy()
return Graph(
consts={ConstKey(key=key): Const.from_dict(value) for key, value in data.pop("consts").items()},
consts={ConstKey(key=key): Const.from_json(value) for key, value in data.pop("consts").items()},
inputs={
InputKey(key=key): cast(ConstKey, _get_key_from_str(value)) if value is not None else None
for key, value in data.pop("inputs").items()
},
nodes={NodeKey(key=key): _graph_node_from_dict(value) for key, value in data.pop("nodes").items()},
nodes={NodeKey(key=key): _graph_node_from_json(value) for key, value in data.pop("nodes").items()},
outputs={OutputKey(key=key): _get_key_from_str(value) for key, value in data.pop("outputs").items()},
**data,
)

def to_dict(self) -> Dict[str, Any]:
def to_json(self) -> Dict[str, Any]:
"""
Convert graph representation to serializable dictionary
Convert graph representation to json serializable dictionary
:return: graph dictionary
:return: json serializable dictionary
"""
graph_dict: GraphDict = {"consts": {}, "inputs": {}, "nodes": {}, "edges": [], "outputs": {}}

for const_node_key, const_node in self.consts.items():
graph_dict["consts"][const_node_key.key] = const_node.to_dict()
graph_dict["consts"][const_node_key.key] = const_node.to_json()

for input_node_key, input_node in self.inputs.items():
graph_dict["inputs"][input_node_key.key] = input_node.to_str() if input_node is not None else None

for func_node_key, func_node in self.nodes.items():
func_node_dict = func_node.to_dict()
func_node_dict = func_node.to_json()
func_node_dict.pop("args")

graph_dict["nodes"][func_node_key.key] = func_node_dict
Expand All @@ -483,16 +483,50 @@ def to_dict(self) -> Dict[str, Any]:

return cast(dict, graph_dict)

def to_dict(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]:
"""
Convert graph to dict graph
Dict graph representation:
.. code-block:: json
{
"a": 1,
"b": 2,
"sum": (add, "a", "b")
}
Values can be:
- Tasks: represented as tuples with the format ``(fn, *args)``
- Constants: all other values
:param args: positional arguments
:param kwargs: keyword arguments
:return: dict graph and output keys
"""
inputs: dict = {**dict(zip((key.key for key in self.inputs.keys()), args)), **kwargs}
return self._convert_graph_to_dict(inputs=inputs)

def to_dask(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]:
"""
Convert graph to dask graph
.. warning::
This method is deprecated and will be removed in a future release.
Please use :func:`to_dict` instead.
:param args: positional arguments
:param kwargs: keyword arguments
:return: dask graph and output keys
"""
inputs: dict = {**dict(zip(self.inputs.keys(), args)), **kwargs}
return self._convert_graph_to_dask_graph(inputs=inputs)
warnings.warn(
"This method is deprecated and will be removed in a future release. Please use 'to_dict' instead.",
DeprecationWarning,
)
return self.to_dict(*args, **kwargs)

def to_dot(
self,
Expand Down Expand Up @@ -779,37 +813,37 @@ def _create_dot_edge(src: str, dst: str) -> pydot.Edge:

return edge

def _convert_graph_to_dask_graph(
def _convert_graph_to_dict(
self,
inputs: Optional[Dict[str, Any]] = None,
input_mapping: Optional[Dict[InputKey, str]] = None,
output_mapping: Optional[Dict[OutputKey, str]] = None,
) -> Tuple[Dict[str, Any], List[str]]:
"""
Convert our own graph format to a dask graph.
Convert our own graph format to a dict graph.
:param inputs: inputs dictionary
:param input_mapping: input mapping for subgraphs
:param output_mapping: output mapping for subgraphs
:return: tuple containing dask graph and targets
:return: tuple containing dict graph and targets
"""
assert inputs is None or input_mapping is None, "cannot specify both inputs and input_mapping"

dask_graph: dict = {}
dict_graph: dict = {}
key_to_uuid: dict = {}

# create constants
for const_key, const in self.consts.items():
graph_key = f"const_{self._get_const_label(const)}_{uuid.uuid4().hex}"
dask_graph[graph_key] = const.to_value()
dict_graph[graph_key] = const.to_value()
key_to_uuid[const_key] = graph_key

# create inputs
if inputs is not None:
for input_key in self.inputs.keys():
graph_key = f"input_{input_key.key}_{uuid.uuid4().hex}"
# if input key is not in inputs, use the default value
dask_graph[graph_key] = (
dict_graph[graph_key] = (
inputs[input_key.key] if input_key.key in inputs else self.consts[self.inputs[input_key]].to_value()
)
key_to_uuid[input_key] = graph_key
Expand Down Expand Up @@ -845,7 +879,7 @@ def _convert_graph_to_dask_graph(
else:
key_to_uuid[input_key] = key_to_uuid[const_path]

# build dask graph
# build dict graph
for node_key, node in self.nodes.items():
if isinstance(node, FunctionCall):
assert callable(node.function)
Expand All @@ -862,7 +896,7 @@ def _convert_graph_to_dask_graph(
# handle default arguments
if param_name not in node.args:
graph_key = f"const_{self._get_const_label(input_annotation.default)}_{uuid.uuid4().hex}"
dask_graph[graph_key] = input_annotation.default
dict_graph[graph_key] = input_annotation.default
args.append(graph_key)
continue

Expand All @@ -884,10 +918,10 @@ def _convert_graph_to_dask_graph(
break

constant_key = f"const_{self._get_const_label(output_position)}_{uuid.uuid4().hex}"
dask_graph[constant_key] = output_position
dask_graph[graph_key] = (_unpack_tuple, node_uuid, constant_key)
dict_graph[constant_key] = output_position
dict_graph[graph_key] = (_unpack_tuple, node_uuid, constant_key)

dask_graph[node_uuid] = (node.function,) + tuple(args)
dict_graph[node_uuid] = (node.function,) + tuple(args)

elif isinstance(node, GraphCall):
new_input_mapping = {
Expand All @@ -897,12 +931,12 @@ def _convert_graph_to_dask_graph(
output_key: key_to_uuid[NodeOutputKey(key=node_key.key, output=output_key.key)]
for output_key in node.graph.outputs
}
dask_subgraph, _ = node.graph._convert_graph_to_dask_graph(
dict_subgraph, _ = node.graph._convert_graph_to_dict(
input_mapping=new_input_mapping, output_mapping=new_output_mapping
)
dask_graph.update(dask_subgraph)
dict_graph.update(dict_subgraph)

return dask_graph, [key_to_uuid[output_path] for output_path in self.outputs.values()]
return dict_graph, [key_to_uuid[output_path] for output_path in self.outputs.values()]

def _scramble_keys(
self, old_to_new: Optional[bidict[Union[ConstKey, NodeKey], Union[ConstKey, NodeKey]]] = None
Expand Down
Loading

0 comments on commit fa65c0c

Please sign in to comment.