diff --git a/graphkit/functional.py b/graphkit/functional.py
index 65388973..9d88cdf9 100644
--- a/graphkit/functional.py
+++ b/graphkit/functional.py
@@ -3,6 +3,8 @@
 
 from itertools import chain
 
+from boltons.setutils import IndexedSet as iset
+
 from .base import Operation, NetworkOperation
 from .network import Network
 from .modifiers import optional
@@ -28,7 +30,7 @@ def _compute(self, named_inputs, outputs=None):
 
         result = zip(self.provides, result)
         if outputs:
-            outputs = set(outputs)
+            outputs = sorted(set(outputs))
             result = filter(lambda x: x[0] in outputs, result)
 
         return dict(result)
@@ -185,23 +187,27 @@ def __call__(self, *operations):
 
         # If merge is desired, deduplicate operations before building network
         if self.merge:
-            merge_set = set()
+            merge_set = iset()  # Preserve given node order.
             for op in operations:
                 if isinstance(op, NetworkOperation):
                     net_ops = filter(lambda x: isinstance(x, Operation), op.net.steps)
                     merge_set.update(net_ops)
                 else:
                     merge_set.add(op)
-            operations = list(merge_set)
+            operations = merge_set
 
         def order_preserving_uniquifier(seq, seen=None):
-            seen = seen if seen else set()
+            seen = seen if seen else set()  # unordered, not iterated
            seen_add = seen.add
            return [x for x in seq if not (x in seen or seen_add(x))]
 
         provides = order_preserving_uniquifier(chain(*[op.provides for op in operations]))
-        needs = order_preserving_uniquifier(chain(*[op.needs for op in operations]), set(provides))
-
+        needs = order_preserving_uniquifier(chain(*[op.needs for op in operations]),
+                                            set(provides))  # unordered, not iterated
+        # Mark them all as optional, now that #18 calmly ignores
+        # non-fully satisfied operations.
+        needs = [optional(n) for n in needs]
+
         # compile network
         net = Network()
         for op in operations:
diff --git a/graphkit/network.py b/graphkit/network.py
index 0df3ddf8..cad561ee 100644
--- a/graphkit/network.py
+++ b/graphkit/network.py
@@ -5,9 +5,15 @@
 import os
 
 import networkx as nx
+from collections import defaultdict
 from io import StringIO
+from itertools import chain
+
+
+from boltons.setutils import IndexedSet as iset
 
 from .base import Operation
+from .modifiers import optional
 
 
 class DataPlaceholderNode(str):
@@ -107,7 +113,7 @@ def compile(self):
         self.steps = []
 
         # create an execution order such that each layer's needs are provided.
-        ordered_nodes = list(nx.dag.topological_sort(self.graph))
+        ordered_nodes = iset(nx.topological_sort(self.graph))
 
         # add Operations evaluation steps, and instructions to free data.
         for i, node in enumerate(ordered_nodes):
@@ -123,22 +129,57 @@ def compile(self):
                 # Add instructions to delete predecessors as possible.  A
                 # predecessor may be deleted if it is a data placeholder that
                 # is no longer needed by future Operations.
-                for predecessor in self.graph.predecessors(node):
+                for need in self.graph.pred[node]:
                     if self._debug:
-                        print("checking if node %s can be deleted" % predecessor)
-                    predecessor_still_needed = False
+                        print("checking if node %s can be deleted" % need)
                     for future_node in ordered_nodes[i+1:]:
-                        if isinstance(future_node, Operation):
-                            if predecessor in future_node.needs:
-                                predecessor_still_needed = True
-                                break
-                    if not predecessor_still_needed:
+                        if isinstance(future_node, Operation) and need in future_node.needs:
+                            break
+                    else:
                         if self._debug:
-                            print("  adding delete instruction for %s" % predecessor)
-                        self.steps.append(DeleteInstruction(predecessor))
+                            print("  adding delete instruction for %s" % need)
+                        self.steps.append(DeleteInstruction(need))
 
             else:
-                raise TypeError("Unrecognized network graph node")
+                raise TypeError("Unrecognized network graph node %s" % type(node))
+
+
+    def _collect_unsatisfiable_operations(self, necessary_nodes, inputs):
+        """
+        Traverse the ordered graph and mark satisfied needs on each operation,
+        collecting those that are missing at least one.
+
+        Since the graph is ordered, as soon as we reach an operation,
+        all of its needs have already been accounted for, so we can decide its satisfaction.
+
+        :param necessary_nodes:
+            the subset of the graph to consider, but WITHOUT the initial data
+            (because that is what :meth:`_find_necessary_steps()` gives us...)
+        :param inputs:
+            an iterable of the names of the input values
+        :return:
+            a list of unsatisfiable operations
+        """
+        G = self.graph  # shortcut
+        ok_data = set(inputs)  # to collect producible data
+        op_satisfaction = defaultdict(set)  # to collect each operation's satisfied needs
+        unsatisfiables = []  # to collect operations with partial needs
+        # We also need inputs to mark op_satisfaction.
+        nodes = chain(necessary_nodes, inputs)  # note that `inputs` are plain strings
+        for node in nx.topological_sort(G.subgraph(nodes)):
+            if isinstance(node, Operation):
+                real_needs = set(n for n in node.needs if not isinstance(n, optional))
+                if real_needs.issubset(op_satisfaction[node]):
+                    # mark all future data-provides as ok
+                    ok_data.update(G.adj[node])
+                else:
+                    unsatisfiables.append(node)
+            elif isinstance(node, (DataPlaceholderNode, str)) and node in ok_data:
+                # mark satisfied-needs on all future operations
+                for future_op in G.adj[node]:
+                    op_satisfaction[future_op].add(node)
+
+        return unsatisfiables
 
 
     def _find_necessary_steps(self, outputs, inputs):
@@ -163,7 +204,7 @@ def _find_necessary_steps(self, outputs, inputs):
         """
 
         # return steps if it has already been computed before for this set of inputs and outputs
-        outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set)) else outputs
+        outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set, iset)) else outputs
         inputs_keys = tuple(sorted(inputs.keys()))
         cache_key = (inputs_keys, outputs)
         if cache_key in self._necessary_steps_cache:
@@ -175,7 +216,7 @@ def _find_necessary_steps(self, outputs, inputs):
 
             # If caller requested all outputs, the necessary nodes are all
             # nodes that are reachable from one of the inputs.  Ignore input
             # names that aren't in the graph.
-            necessary_nodes = set()
+            necessary_nodes = set()  # unordered, not iterated
             for input_name in iter(inputs):
                 if graph.has_node(input_name):
                     necessary_nodes |= nx.descendants(graph, input_name)
@@ -186,7 +227,7 @@ def _find_necessary_steps(self, outputs, inputs):
             # are made unecessary because we were provided with an input that's
             # deeper into the network graph.  Ignore input names that aren't
             # in the graph.
-            unnecessary_nodes = set()
+            unnecessary_nodes = set()  # unordered, not iterated
             for input_name in iter(inputs):
                 if graph.has_node(input_name):
                     unnecessary_nodes |= nx.ancestors(graph, input_name)
@@ -194,7 +235,7 @@ def _find_necessary_steps(self, outputs, inputs):
             # Find the nodes we need to be able to compute the requested
             # outputs.  Raise an exception if a requested output doesn't
             # exist in the graph.
-            necessary_nodes = set()
+            necessary_nodes = set()  # unordered, not iterated
             for output_name in outputs:
                 if not graph.has_node(output_name):
                     raise ValueError("graphkit graph does not have an output "
@@ -204,6 +245,11 @@ def _find_necessary_steps(self, outputs, inputs):
             # Get rid of the unnecessary nodes from the set of necessary ones.
             necessary_nodes -= unnecessary_nodes
 
+        # Drop (unsatisfiable) operations with partial inputs.
+        # See yahoo/graphkit#18
+        #
+        unsatisfiables = self._collect_unsatisfiable_operations(necessary_nodes, inputs)
+        necessary_nodes -= set(unsatisfiables)
 
         necessary_steps = [step for step in self.steps if step in necessary_nodes]
 
@@ -266,7 +312,7 @@ def _compute_thread_pool_barrier_method(self, named_inputs, outputs,
         necessary_nodes = self._find_necessary_steps(outputs, named_inputs)
 
         # this keeps track of all nodes that have already executed
-        has_executed = set()
+        has_executed = set()  # unordered, not iterated
 
         # with each loop iteration, we determine a set of operations that can be
         # scheduled, then schedule them onto a thread pool, then collect their
@@ -422,8 +468,8 @@ def get_node_name(a):
 
     # save plot
     if filename:
-        basename, ext = os.path.splitext(filename)
-        with open(filename, "w") as fh:
+        _basename, ext = os.path.splitext(filename)
+        with open(filename, "wb") as fh:
             if ext.lower() == ".png":
                 fh.write(g.create_png())
             elif ext.lower() == ".dot":
@@ -464,6 +510,7 @@ def ready_to_schedule_operation(op, has_executed, graph):
         A boolean indicating whether the operation may be scheduled for
         execution based on what has already been executed.
     """
+    # unordered, not iterated
     dependencies = set(filter(lambda v: isinstance(v, Operation),
                               nx.ancestors(graph, op)))
     return dependencies.issubset(has_executed)
diff --git a/setup.py b/setup.py
index bd7883f4..46a69077 100644
--- a/setup.py
+++ b/setup.py
@@ -28,7 +28,11 @@
     author_email='huyng@yahoo-inc.com',
     url='http://github.com/yahoo/graphkit',
     packages=['graphkit'],
-    install_requires=['networkx'],
+    install_requires=[
+        "networkx; python_version >= '3.5'",
+        "networkx == 2.2; python_version < '3.5'",
+        "boltons"  # for IndexedSet
+    ],
     extras_require={
         'plot': ['pydot', 'matplotlib']
     },
diff --git a/test/test_graphkit.py b/test/test_graphkit.py
index bd97b317..a6d4dcb3 100644
--- a/test/test_graphkit.py
+++ b/test/test_graphkit.py
@@ -5,7 +5,7 @@
 import pickle
 
 from pprint import pprint
-from operator import add
+from operator import add, sub, floordiv, mul
 from numpy.testing import assert_raises
 
 import graphkit.network as network
@@ -184,6 +184,38 @@ def test_pruning_raises_for_bad_output():
                       outputs=['sum1', 'sum3', 'sum4'])
 
 
+def test_unsatisfied_operations():
+    # Test that operations with partial inputs are culled and do not fail.
+    graph = compose(name="graph")(
+        operation(name="add", needs=["a", "b1"], provides=["a+b1"])(add),
+        operation(name="sub", needs=["a", "b2"], provides=["a-b2"])(sub),
+    )
+
+    exp = {"a": 10, "b1": 2, "a+b1": 12}
+    assert graph({"a": 10, "b1": 2}) == exp
+    assert graph({"a": 10, "b1": 2}, outputs=["a+b1"]) == {"a+b1": 12}
+
+    exp = {"a": 10, "b2": 2, "a-b2": 8}
+    assert graph({"a": 10, "b2": 2}) == exp
+    assert graph({"a": 10, "b2": 2}, outputs=["a-b2"]) == {"a-b2": 8}
+
+def test_unsatisfied_operations_same_out():
+    # Test unsatisfied pairs of operations providing the same output.
+    graph = compose(name="graph")(
+        operation(name="mul", needs=["a", "b1"], provides=["ab"])(mul),
+        operation(name="div", needs=["a", "b2"], provides=["ab"])(floordiv),
+        operation(name="add", needs=["ab", "c"], provides=["ab_plus_c"])(add),
+    )
+
+    exp = {"a": 10, "b1": 2, "c": 1, "ab": 20, "ab_plus_c": 21}
+    assert graph({"a": 10, "b1": 2, "c": 1}) == exp
+    assert graph({"a": 10, "b1": 2, "c": 1}, outputs=["ab_plus_c"]) == {"ab_plus_c": 21}
+
+    exp = {"a": 10, "b2": 2, "c": 1, "ab": 5, "ab_plus_c": 6}
+    assert graph({"a": 10, "b2": 2, "c": 1}) == exp
+    assert graph({"a": 10, "b2": 2, "c": 1}, outputs=["ab_plus_c"]) == {"ab_plus_c": 6}
+
+
 def test_optional():
     # Test that optional() needs work as expected.
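
Editor's note: the heart of this patch is the satisfiability pass added to the network's _find_necessary_steps(): walk the pruned graph in topological order, mark data that can actually be produced from the given inputs, and drop every operation whose non-optional needs are not all satisfied. The snippet below is a minimal standalone sketch of that idea written against plain networkx, not graphkit's API; the `ops` table and all node names are hypothetical, and the handling of `optional` needs is omitted for brevity.

    # A standalone sketch (NOT graphkit's API) of the culling idea behind
    # _collect_unsatisfiable_operations(); the `ops` table below is made up.
    from collections import defaultdict

    import networkx as nx

    # op name -> (needs, provides)
    ops = {
        "add": (["a", "b1"], ["a+b1"]),
        "sub": (["a", "b2"], ["a-b2"]),
    }
    inputs = ["a", "b1"]  # "b2" is missing, so "sub" cannot run

    # Build a bipartite data/operation graph: data -> op -> data edges.
    G = nx.DiGraph()
    for name, (needs, provides) in ops.items():
        G.add_edges_from((n, name) for n in needs)
        G.add_edges_from((name, p) for p in provides)

    ok_data = set(inputs)          # data that can actually be produced
    satisfied = defaultdict(set)   # op name -> needs accounted for so far
    unsatisfiable = []

    for node in nx.topological_sort(G):
        if node in ops:
            # An operation node: all of its needs were already visited.
            if set(ops[node][0]) <= satisfied[node]:
                ok_data.update(G.adj[node])   # its provides become available
            else:
                unsatisfiable.append(node)
        elif node in ok_data:
            # A producible data node: mark it satisfied on downstream ops.
            for op in G.adj[node]:
                satisfied[op].add(node)

    print(unsatisfiable)  # -> ['sub']

Running the sketch prints ['sub'], mirroring what test_unsatisfied_operations() above exercises through the public API when only "b1" (and not "b2") is provided.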