Skip to content

Commit

Permalink
ENH(net,#18): ignore UN-SATISFIABLE operations with partial inputs
Browse files Browse the repository at this point in the history
Useful when 2 (or more) operations provide the same output,
and only one of them has fully satisfied inputs.
Previously, it would fail while trying to evaluate the unsatisfied ones.

+ New TC added.
.
  • Loading branch information
ankostis committed Sep 29, 2019
1 parent 617e577 commit 1519baa
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 1 deletion.
69 changes: 69 additions & 0 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from io import StringIO

from .base import Operation
from .modifiers import optional


class DataPlaceholderNode(str):
Expand Down Expand Up @@ -141,6 +142,67 @@ def compile(self):
raise TypeError("Unrecognized network graph node")


def _collect_satisfiable_needs(self, operation, inputs, satisfiables, visited):
    """
    Recursively decide whether all non-optional needs of an operation can be met.

    A need is met when it is contained in `inputs`, or when at least one of
    its predecessors in ``self.graph`` is itself a satisfiable operation.

    :param operation:
        the operation to examine (must be an :class:`Operation`)
    :param inputs:
        the given inputs; a need found in here counts as satisfied
    :param satisfiables:
        the set to populate with satisfiable operations
    :param visited:
        a cache mapping operations & needs to their verdict,
        so that they are not visited again
    :return:
        true if the operation is satisfiable
    """
    assert isinstance(operation, Operation), (
        "Expected Operation, got:",
        type(operation),
    )

    # Already judged on an earlier visit: reuse the cached verdict.
    if operation in visited:
        return visited[operation]

    def can_resolve(need):
        # Cached needs are answered without walking the graph again.
        if need in visited:
            return visited[need]

        if need in inputs:
            outcome = True
        else:
            # Not given: it is resolvable only if some provider is satisfiable.
            # Stop at the first satisfiable provider found.
            outcome = False
            providers = [*self.graph.predecessors(need)]
            for provider in providers:
                if self._collect_satisfiable_needs(
                    provider, inputs, satisfiables, visited
                ):
                    outcome = True
                    break
        visited[need] = outcome

        return outcome

    # Optional needs never block an operation; stop at the first failing need.
    verdict = True
    for need in operation.needs:
        if isinstance(need, optional):
            continue
        if not can_resolve(need):
            verdict = False
            break

    visited[operation] = verdict
    if verdict:
        satisfiables.add(operation)

    return verdict


def _collect_satisfiable_operations(self, nodes, inputs):
satisfiables = set()
from unittest import mock

visited = {}
for node in nodes:
if node not in visited and isinstance(node, Operation):
self._collect_satisfiable_needs(node, inputs, satisfiables, visited)

return satisfiables


def _find_necessary_steps(self, outputs, inputs):
"""
Determines what graph steps need to be run to get to the requested
Expand Down Expand Up @@ -204,6 +266,13 @@ def _find_necessary_steps(self, outputs, inputs):
# Get rid of the unnecessary nodes from the set of necessary ones.
necessary_nodes -= unnecessary_nodes

# Drop (un-satisfiable) operations with partial inputs.
# See https://github.com/yahoo/graphkit/pull/18
#
satisfiables = self._collect_satisfiable_operations(necessary_nodes, inputs)
for node in list(necessary_nodes):
if isinstance(node, Operation) and node not in satisfiables:
necessary_nodes.remove(node)

necessary_steps = [step for step in self.steps if step in necessary_nodes]

Expand Down
18 changes: 17 additions & 1 deletion test/test_graphkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pickle

from pprint import pprint
from operator import add
from operator import add, mul, floordiv
from numpy.testing import assert_raises

import graphkit.network as network
Expand Down Expand Up @@ -69,6 +69,22 @@ def pow_op1(a, exponent=2):
# net.plot(show=True)


def test_operations_with_partial_inputs_ignored():
    """
    Two operations provide "ab"; only the one whose needs are all given must run.
    """
    pipeline = compose(name="graph")(
        operation(name="mul", needs=["a", "b1"], provides=["ab"])(mul),
        operation(name="div", needs=["a", "b2"], provides=["ab"])(floordiv),
        operation(name="add", needs=["ab", "c"], provides=["ab_plus_c"])(add),
    )

    # (given inputs, expected "ab", expected "ab_plus_c")
    scenarios = [
        ({"a": 10, "b1": 2, "c": 1}, 20, 21),  # only "mul" has all its needs
        ({"a": 10, "b2": 2, "c": 1}, 5, 6),  # only "div" has all its needs
    ]
    for given, ab, total in scenarios:
        everything = dict(given, ab=ab, ab_plus_c=total)
        # A fresh dict is passed on every call, as separate literals would be.
        assert pipeline(dict(given)) == everything
        assert pipeline(dict(given), outputs=["ab_plus_c"]) == {"ab_plus_c": total}


def test_network_simple_merge():

sum_op1 = operation(name='sum_op1', needs=['a', 'b'], provides='sum1')(add)
Expand Down

0 comments on commit 1519baa

Please sign in to comment.