diff --git a/graphkit/network.py b/graphkit/network.py index 0df3ddf8..6abf53f3 100644 --- a/graphkit/network.py +++ b/graphkit/network.py @@ -8,6 +8,7 @@ from io import StringIO from .base import Operation +from .modifiers import optional class DataPlaceholderNode(str): @@ -141,6 +142,65 @@ def compile(self): raise TypeError("Unrecognized network graph node") + def _collect_satisfiable_needs(self, operation, inputs, satisfiables, visited): + """ + Recusrively check if operation inputs are given/calculated (satisfied), or not. + + :param satisfiables: + the set to populate with satisfiable operations + + :param visited: + a cache of operations & needs, not to visit them again + :return: + true if opearation is satisfiable + """ + assert isinstance(operation, Operation), ( + "Expected Operation, got:", + type(operation), + ) + + if operation in visited: + return visited[operation] + + + def is_need_satisfiable(need): + if need in visited: + return visited[need] + + if need in inputs: + satisfied = True + else: + need_providers = list(self.graph.predecessors(need)) + satisfied = bool(need_providers) and any( + self._collect_satisfiable_needs(op, inputs, satisfiables, visited) + for op in need_providers + ) + visited[need] = satisfied + + return satisfied + + satisfied = all( + is_need_satisfiable(need) + for need in operation.needs + if not isinstance(need, optional) + ) + if satisfied: + satisfiables.add(operation) + visited[operation] = satisfied + + return satisfied + + + def _collect_satisfiable_operations(self, nodes, inputs): + satisfiables = set() + visited = {} + for node in nodes: + if node not in visited and isinstance(node, Operation): + self._collect_satisfiable_needs(node, inputs, satisfiables, visited) + + return satisfiables + + def _find_necessary_steps(self, outputs, inputs): """ Determines what graph steps need to pe run to get to the requested @@ -204,6 +264,13 @@ def _find_necessary_steps(self, outputs, inputs): # Get rid of the unnecessary nodes from the set of necessary ones. necessary_nodes -= unnecessary_nodes + # Drop (un-satifiable) operations with partial inputs. + # See https://github.com/yahoo/graphkit/pull/18 + # + satisfiables = self._collect_satisfiable_operations(necessary_nodes, inputs) + for node in list(necessary_nodes): + if isinstance(node, Operation) and node not in satisfiables: + necessary_nodes.remove(node) necessary_steps = [step for step in self.steps if step in necessary_nodes] diff --git a/test/test_graphkit.py b/test/test_graphkit.py index bd97b317..1cb1c96a 100644 --- a/test/test_graphkit.py +++ b/test/test_graphkit.py @@ -5,7 +5,7 @@ import pickle from pprint import pprint -from operator import add +from operator import add, mul, floordiv from numpy.testing import assert_raises import graphkit.network as network @@ -69,6 +69,22 @@ def pow_op1(a, exponent=2): # net.plot(show=True) +def test_operations_with_partial_inputs_ignored(): + graph = compose(name="graph")( + operation(name="mul", needs=["a", "b1"], provides=["ab"])(mul), + operation(name="div", needs=["a", "b2"], provides=["ab"])(floordiv), + operation(name="add", needs=["ab", "c"], provides=["ab_plus_c"])(add), + ) + + exp = {"a": 10, "b1": 2, "c": 1, "ab": 20, "ab_plus_c": 21} + assert graph({"a": 10, "b1": 2, "c": 1}) == exp + assert graph({"a": 10, "b1": 2, "c": 1}, outputs=["ab_plus_c"]) == {"ab_plus_c": 21} + + exp = {"a": 10, "b2": 2, "c": 1, "ab": 5, "ab_plus_c": 6} + assert graph({"a": 10, "b2": 2, "c": 1}) == exp + assert graph({"a": 10, "b2": 2, "c": 1}, outputs=["ab_plus_c"]) == {"ab_plus_c": 6} + + def test_network_simple_merge(): sum_op1 = operation(name='sum_op1', needs=['a', 'b'], provides='sum1')(add)