diff --git a/graphkit/network.py b/graphkit/network.py
index 06447271..ce74fe30 100644
--- a/graphkit/network.py
+++ b/graphkit/network.py
@@ -137,7 +137,7 @@ def show_layers(self, debug=False, ret=False):
     def _build_execution_plan(self, dag):
         """
         Create the list of operation-nodes & *instructions* evaluating all
-
+        operations & instructions needed a) to free memory and b) avoid
         overwritting given intermediate inputs.
@@ -187,45 +187,55 @@ def _build_execution_plan(self, dag):
         return plan
 
-    def _collect_unsatisfiable_operations(self, necessary_nodes, inputs):
+    def _collect_unsatisfied_operations(self, dag, inputs):
         """
-        Traverse ordered graph and mark satisfied needs on each operation,
+        Traverse topologically sorted dag to collect un-satisfied operations.
+
+        Unsatisfied operations are those suffering from ANY of the following:
-        collecting those missing at least one.
-        Since the graph is ordered, as soon as we're on an operation,
-        all its needs have been accounted, so we can get its satisfaction.
+        - They are missing at least one compulsory need-input.
+          Since the dag is ordered, as soon as we're on an operation,
+          all its needs have been accounted, so we can get its satisfaction.
-        :param necessary_nodes:
-            the subset of the graph to consider but WITHOUT the initial data
-            (because that is what :meth:`compile()` can gives us...)
+        - Their provided outputs are not linked to any data in the dag.
+          An operation might not have any output link when :meth:`_solve_dag()`
+          has broken them, due to given intermediate inputs.
+
+        :param dag:
+            the graph to consider
 
         :param inputs:
             an iterable of the names of the input values
 
        return:
-            a list of unsatisfiable operations
+            a list of unsatisfied operations to prune
        """
-        G = self.graph  # shortcut
-        ok_data = set(inputs)  # to collect producible data
-        op_satisfaction = defaultdict(set)  # to collect operation satisfiable needs
-        unsatisfiables = []  # to collect operations with partial needs
-        # We also need inputs to mark op_satisfaction.
-        nodes = chain(necessary_nodes, inputs)  # note that `inputs` are plain strings
-        for node in nx.topological_sort(G.subgraph(nodes)):
+        # To collect data that will be produced.
+        ok_data = set(inputs)
+        # To collect the map of operations --> satisfied-needs.
+        op_satisfaction = defaultdict(set)
+        # To collect the operations to drop.
+        unsatisfied = []
+        for node in nx.topological_sort(dag):
             if isinstance(node, Operation):
-                real_needs = set(n for n in node.needs if not isinstance(n, optional))
-                if real_needs.issubset(op_satisfaction[node]):
-                    # mark all future data-provides as ok
-                    ok_data.update(G.adj[node])
+                if not dag.adj[node]:
+                    # Prune operations that ended up providing no output.
+                    unsatisfied.append(node)
                 else:
-                    unsatisfiables.append(node)
+                    real_needs = set(n for n in node.needs if not isinstance(n, optional))
+                    if real_needs.issubset(op_satisfaction[node]):
+                        # We have a satisfied operation; mark its output-data
+                        # as ok.
+                        ok_data.update(dag.adj[node])
+                    else:
+                        # Prune operations with partial inputs.
+                        unsatisfied.append(node)
             elif isinstance(node, (DataPlaceholderNode, str)):  # `str` are givens
                 if node in ok_data:
                     # mark satisfied-needs on all future operations
-                    for future_op in G.adj[node]:
+                    for future_op in dag.adj[node]:
                         op_satisfaction[future_op].add(node)
             else:
                 raise AssertionError("Unrecognized network graph node %r" % node)
 
-        return unsatisfiables
+        return unsatisfied
 
 
     def _solve_dag(self, outputs, inputs):
@@ -245,60 +255,56 @@ def _solve_dag(self, outputs, inputs):
             The inputs names of all given inputs.
 
         :return:
-            the subgraph comprising the solution
-
+            the *execution plan*
         """
-        graph = self.graph
-        if not outputs:
+        dag = self.graph
 
-            # If caller requested all outputs, the necessary nodes are all
-            # nodes that are reachable from one of the inputs. Ignore input
-            # names that aren't in the graph.
-            necessary_nodes = set()  # unordered, not iterated
-            for input_name in iter(inputs):
-                if graph.has_node(input_name):
-                    necessary_nodes |= nx.descendants(graph, input_name)
+        # Ignore input names that aren't in the graph.
+        graph_inputs = iset(dag.nodes) & inputs  # preserve order
 
-        else:
+        # Scream if some requested outputs aren't in the graph.
+        unknown_outputs = iset(outputs) - dag.nodes
+        if unknown_outputs:
+            raise ValueError(
+                "Unknown output node(s) requested: %s"
+                % ", ".join(unknown_outputs))
+
+        broken_dag = dag.copy()  # preserve net's graph
 
-            # If the caller requested a subset of outputs, find any nodes that
-            # are made unecessary because we were provided with an input that's
-            # deeper into the network graph. Ignore input names that aren't
-            # in the graph.
-            unnecessary_nodes = set()  # unordered, not iterated
-            for input_name in iter(inputs):
-                if graph.has_node(input_name):
-                    unnecessary_nodes |= nx.ancestors(graph, input_name)
-
-            # Find the nodes we need to be able to compute the requested
-            # outputs. Raise an exception if a requested output doesn't
-            # exist in the graph.
-            necessary_nodes = set()  # unordered, not iterated
-            for output_name in outputs:
-                if not graph.has_node(output_name):
-                    raise ValueError("graphkit graph does not have an output "
-                                     "node named %s" % output_name)
-                necessary_nodes |= nx.ancestors(graph, output_name)
-
-            # Get rid of the unnecessary nodes from the set of necessary ones.
-            necessary_nodes -= unnecessary_nodes
-
-        # Drop (un-satifiable) operations with partial inputs.
+        # Break the incoming edges to all given inputs.
+        #
+        # Nodes producing any given intermediate inputs are unnecessary
+        # (unless they are also used elsewhere).
+        # To discover which ones to prune, we break their incoming edges
+        # and they will drop out while collecting ancestors from the outputs.
+        for given in graph_inputs:
+            broken_dag.remove_edges_from(list(broken_dag.in_edges(given)))
+
+        if outputs:
+            # If caller requested specific outputs, we can prune any
+            # unrelated nodes further up the dag.
+            ending_in_outputs = set()
+            for input_name in outputs:
+                ending_in_outputs.update(nx.ancestors(dag, input_name))
+            broken_dag = broken_dag.subgraph(ending_in_outputs | set(outputs))
+
+
+        # Prune (un-satisfiable) operations with partial inputs.
         # See yahoo/graphkit#18
         #
-        unsatisfiables = self._collect_unsatisfiable_operations(necessary_nodes, inputs)
-        necessary_nodes -= set(unsatisfiables)
+        unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs)
+        shrinked_dag = dag.subgraph(broken_dag.nodes - unsatisfied)
 
-        shrinked_dag = graph.subgraph(necessary_nodes)
+        plan = self._build_execution_plan(shrinked_dag)
 
-        return shrinked_dag
+        return plan
 
     def compile(self, outputs=(), inputs=()):
         """
-        Solve dag, set the :attr:`execution_plan` and cache it.
+        Solve dag, set the :attr:`execution_plan`, and cache it.
 
-        See :meth:`_solve_dag()` for description
+        See :meth:`_solve_dag()` for detailed description.
 
         :param iterable outputs:
            A list of desired output names. This can also be ``None``, in which
@@ -306,7 +312,7 @@ def compile(self, outputs=(), inputs=()):
             from one of the provided inputs.
 
         :param dict inputs:
-            The inputs names of all given inputs.
+            The input names of all given inputs.
         """
 
         # return steps if it has already been computed before for this set of inputs and outputs
@@ -317,8 +323,7 @@
         if cache_key in self._cached_execution_plans:
             self.execution_plan = self._cached_execution_plans[cache_key]
         else:
-            dag = self._solve_dag(outputs, inputs)
-            plan = self._build_execution_plan(dag)
+            plan = self._solve_dag(outputs, inputs)
 
             # save this result in a precomputed cache for future lookup
             self.execution_plan = self._cached_execution_plans[cache_key] = plan
@@ -345,9 +350,10 @@ def compute(self, outputs, named_inputs, method=None):
         assert isinstance(outputs, (list, tuple)) or outputs is None,\
             "The outputs argument must be a list"
 
-        # start with fresh data cache
-        cache = {}
-        cache.update(named_inputs)
+        # start with fresh data cache & overwrites
+        cache = named_inputs.copy()
+
+        # Build and set :attr:`execution_plan`.
         self.compile(outputs, named_inputs.keys())
 
         # choose a method of execution
diff --git a/test/test_graphkit.py b/test/test_graphkit.py
index be4b0e86..01a7f133 100644
--- a/test/test_graphkit.py
+++ b/test/test_graphkit.py
@@ -265,12 +265,13 @@ def test_pruning_with_given_intermediate_and_asked_out():
         operation(name="good_op", needs=["a", "given-2"], provides=["asked"])(add),
     )
 
-    exp = {"given-1": 5, "b": 2, "given-2": 7, "a": 5, "asked": 12}
+    exp = {"given-1": 5, "b": 2, "given-2": 2, "a": 5, "asked": 7}
     # v1.2.4 is ok
     assert netop({"given-1": 5, "b": 2, "given-2": 2}) == exp
     # FAILS
     # - on v1.2.4 with KeyError: 'a',
    # - on #18 (unsatisfied) with no result.
+    # FIXED on #18+#26 (new dag solver).
    assert netop({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")
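
For reviewers, here is a minimal standalone sketch (not part of the patch; the toy node names are made up) of the edge-breaking idea used by the new _solve_dag() above: incoming edges of every given input are removed from a copy of the dag, so an operation that exists only to produce a given intermediate no longer appears among the ancestors of the requested outputs and gets pruned.

import networkx as nx

# Hypothetical toy dag mixing operation nodes and data nodes, as in Network.graph:
# "op_a" produces the intermediate "a"; "op_asked" consumes "a" and "given-2".
dag = nx.DiGraph([
    ("given-1", "op_a"), ("op_a", "a"),
    ("a", "op_asked"), ("given-2", "op_asked"), ("op_asked", "asked"),
])

inputs = ["given-1", "given-2", "a"]      # "a" is a given intermediate input
broken_dag = dag.copy()
for given in inputs:
    # Break the incoming edges of every given input, as _solve_dag() does.
    broken_dag.remove_edges_from(list(broken_dag.in_edges(given)))

# Keep only the nodes still leading to the requested output.
keep = nx.ancestors(broken_dag, "asked") | {"asked"}
print(sorted(keep))   # ['a', 'asked', 'given-2', 'op_asked']

"op_a" drops out of the ancestor collection even though it is still in the graph, so it is pruned when its output "a" is supplied directly.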