Skip to content

Commit

Permalink
ENH(net,#18): ignore UN-SATISFIABLE operations with partial inputs
Browse files Browse the repository at this point in the history
+ The x2 TCs added just before are now passing.
  • Loading branch information
ankostis committed Sep 30, 2019
1 parent f316494 commit 1967995
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from io import StringIO

from .base import Operation
from .modifiers import optional


class DataPlaceholderNode(str):
Expand Down Expand Up @@ -141,6 +142,65 @@ def compile(self):
raise TypeError("Unrecognized network graph node")


def _collect_satisfiable_needs(self, operation, inputs, satisfiables, visited):
"""
Recusrively check if operation inputs are given/calculated (satisfied), or not.
:param satisfiables:
the set to populate with satisfiable operations
:param visited:
a cache of operations & needs, not to visit them again
:return:
true if opearation is satisfiable
"""
assert isinstance(operation, Operation), (
"Expected Operation, got:",
type(operation),
)

if operation in visited:
return visited[operation]


def is_need_satisfiable(need):
if need in visited:
return visited[need]

if need in inputs:
satisfied = True
else:
need_providers = list(self.graph.predecessors(need))
satisfied = bool(need_providers) and any(
self._collect_satisfiable_needs(op, inputs, satisfiables, visited)
for op in need_providers
)
visited[need] = satisfied

return satisfied

satisfied = all(
is_need_satisfiable(need)
for need in operation.needs
if not isinstance(need, optional)
)
if satisfied:
satisfiables.add(operation)
visited[operation] = satisfied

return satisfied


def _collect_satisfiable_operations(self, nodes, inputs):
satisfiables = set()
visited = {}
for node in nodes:
if node not in visited and isinstance(node, Operation):
self._collect_satisfiable_needs(node, inputs, satisfiables, visited)

return satisfiables


def _find_necessary_steps(self, outputs, inputs):
"""
Determines what graph steps need to pe run to get to the requested
Expand Down Expand Up @@ -204,6 +264,13 @@ def _find_necessary_steps(self, outputs, inputs):
# Get rid of the unnecessary nodes from the set of necessary ones.
necessary_nodes -= unnecessary_nodes

# Drop (un-satifiable) operations with partial inputs.
# See https://github.com/yahoo/graphkit/pull/18
#
satisfiables = self._collect_satisfiable_operations(necessary_nodes, inputs)
for node in list(necessary_nodes):
if isinstance(node, Operation) and node not in satisfiables:
necessary_nodes.remove(node)

necessary_steps = [step for step in self.steps if step in necessary_nodes]

Expand Down

0 comments on commit 1967995

Please sign in to comment.