Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: ignore unsatisfied operations with partial inputs #18

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
18 changes: 12 additions & 6 deletions graphkit/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from itertools import chain

from boltons.setutils import IndexedSet as iset

from .base import Operation, NetworkOperation
from .network import Network
from .modifiers import optional
Expand All @@ -28,7 +30,7 @@ def _compute(self, named_inputs, outputs=None):

result = zip(self.provides, result)
if outputs:
outputs = set(outputs)
outputs = sorted(set(outputs))
result = filter(lambda x: x[0] in outputs, result)

return dict(result)
Expand Down Expand Up @@ -185,23 +187,27 @@ def __call__(self, *operations):

# If merge is desired, deduplicate operations before building network
if self.merge:
merge_set = set()
merge_set = iset() # Preseve given node order.
for op in operations:
if isinstance(op, NetworkOperation):
net_ops = filter(lambda x: isinstance(x, Operation), op.net.steps)
merge_set.update(net_ops)
else:
merge_set.add(op)
operations = list(merge_set)
operations = merge_set

def order_preserving_uniquifier(seq, seen=None):
seen = seen if seen else set()
seen = seen if seen else set() # unordered, not iterated
seen_add = seen.add
return [x for x in seq if not (x in seen or seen_add(x))]

provides = order_preserving_uniquifier(chain(*[op.provides for op in operations]))
needs = order_preserving_uniquifier(chain(*[op.needs for op in operations]), set(provides))

needs = order_preserving_uniquifier(chain(*[op.needs for op in operations]),
set(provides)) # unordered, not iterated
# Mark them all as optional, now that #18 calmly ignores
# non-fully satisfied operations.
needs = [optional(n) for op in operations for n in op.needs]

# compile network
net = Network()
for op in operations:
Expand Down
85 changes: 66 additions & 19 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
import os
import networkx as nx

from collections import defaultdict
from io import StringIO
from itertools import chain


from boltons.setutils import IndexedSet as iset

from .base import Operation
from .modifiers import optional


class DataPlaceholderNode(str):
Expand Down Expand Up @@ -107,7 +113,7 @@ def compile(self):
self.steps = []

# create an execution order such that each layer's needs are provided.
ordered_nodes = list(nx.dag.topological_sort(self.graph))
ordered_nodes = iset(nx.topological_sort(self.graph))

# add Operations evaluation steps, and instructions to free data.
for i, node in enumerate(ordered_nodes):
Expand All @@ -123,22 +129,57 @@ def compile(self):
# Add instructions to delete predecessors as possible. A
# predecessor may be deleted if it is a data placeholder that
# is no longer needed by future Operations.
for predecessor in self.graph.predecessors(node):
for need in self.graph.pred[node]:
if self._debug:
print("checking if node %s can be deleted" % predecessor)
predecessor_still_needed = False
print("checking if node %s can be deleted" % need)
for future_node in ordered_nodes[i+1:]:
if isinstance(future_node, Operation):
if predecessor in future_node.needs:
predecessor_still_needed = True
break
if not predecessor_still_needed:
if isinstance(future_node, Operation) and need in future_node.needs:
break
else:
if self._debug:
print(" adding delete instruction for %s" % predecessor)
self.steps.append(DeleteInstruction(predecessor))
print(" adding delete instruction for %s" % need)
self.steps.append(DeleteInstruction(need))

else:
raise TypeError("Unrecognized network graph node")
raise TypeError("Unrecognized network graph node %s" % type(node))


def _collect_unsatisfiable_operations(self, necessary_nodes, inputs):
    """
    Collect the operations that are missing at least one mandatory need.

    The relevant sub-graph is walked in topological order, so by the time
    an operation is visited all the data nodes feeding it have already
    been processed and its set of satisfied needs is complete.

    :param necessary_nodes:
        the subset of the graph to consider, but WITHOUT the initial data
        (because that is what :meth:`_find_necessary_steps()` can give us)
    :param inputs:
        an iterable of the names of the input values
    :return:
        a list of unsatisfiable operations
    """
    graph = self.graph
    available = set(inputs)          # data given or producible so far
    satisfied = defaultdict(set)     # operation -> needs known to be available
    unsatisfiables = []              # operations with only partial needs

    # The inputs (plain strings) are added to the traversal so that they
    # can mark the needs of the operations consuming them.
    subgraph = graph.subgraph(chain(necessary_nodes, inputs))
    for node in nx.topological_sort(subgraph):
        if isinstance(node, Operation):
            # Needs wrapped as `optional` never block an operation.
            mandatory = {
                need for need in node.needs if not isinstance(need, optional)
            }
            if not mandatory.issubset(satisfied[node]):
                unsatisfiables.append(node)
                continue
            # Everything this operation provides becomes available.
            available.update(graph.adj[node])
            continue

        if not isinstance(node, (DataPlaceholderNode, str)):
            continue
        if node in available:
            # Record this data as a satisfied need of each consumer.
            for consumer in graph.adj[node]:
                satisfied[consumer].add(node)

    return unsatisfiables


def _find_necessary_steps(self, outputs, inputs):
Expand All @@ -163,7 +204,7 @@ def _find_necessary_steps(self, outputs, inputs):
"""

# return steps if it has already been computed before for this set of inputs and outputs
outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set)) else outputs
outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set, iset)) else outputs
inputs_keys = tuple(sorted(inputs.keys()))
cache_key = (inputs_keys, outputs)
if cache_key in self._necessary_steps_cache:
Expand All @@ -175,7 +216,7 @@ def _find_necessary_steps(self, outputs, inputs):
# If caller requested all outputs, the necessary nodes are all
# nodes that are reachable from one of the inputs. Ignore input
# names that aren't in the graph.
necessary_nodes = set()
necessary_nodes = set() # unordered, not iterated
for input_name in iter(inputs):
if graph.has_node(input_name):
necessary_nodes |= nx.descendants(graph, input_name)
Expand All @@ -186,15 +227,15 @@ def _find_necessary_steps(self, outputs, inputs):
# are made unnecessary because we were provided with an input that's
# deeper into the network graph. Ignore input names that aren't
# in the graph.
unnecessary_nodes = set()
unnecessary_nodes = set() # unordered, not iterated
for input_name in iter(inputs):
if graph.has_node(input_name):
unnecessary_nodes |= nx.ancestors(graph, input_name)

# Find the nodes we need to be able to compute the requested
# outputs. Raise an exception if a requested output doesn't
# exist in the graph.
necessary_nodes = set()
necessary_nodes = set() # unordered, not iterated
for output_name in outputs:
if not graph.has_node(output_name):
raise ValueError("graphkit graph does not have an output "
Expand All @@ -204,6 +245,11 @@ def _find_necessary_steps(self, outputs, inputs):
# Get rid of the unnecessary nodes from the set of necessary ones.
necessary_nodes -= unnecessary_nodes

# Drop (unsatisfiable) operations with partial inputs.
# See yahoo/graphkit#18
#
unsatisfiables = self._collect_unsatisfiable_operations(necessary_nodes, inputs)
necessary_nodes -= set(unsatisfiables)

necessary_steps = [step for step in self.steps if step in necessary_nodes]

Expand Down Expand Up @@ -266,7 +312,7 @@ def _compute_thread_pool_barrier_method(self, named_inputs, outputs,
necessary_nodes = self._find_necessary_steps(outputs, named_inputs)

# this keeps track of all nodes that have already executed
has_executed = set()
has_executed = set() # unordered, not iterated

# with each loop iteration, we determine a set of operations that can be
# scheduled, then schedule them onto a thread pool, then collect their
Expand Down Expand Up @@ -422,8 +468,8 @@ def get_node_name(a):

# save plot
if filename:
basename, ext = os.path.splitext(filename)
with open(filename, "w") as fh:
_basename, ext = os.path.splitext(filename)
with open(filename, "wb") as fh:
if ext.lower() == ".png":
fh.write(g.create_png())
elif ext.lower() == ".dot":
Expand Down Expand Up @@ -464,6 +510,7 @@ def ready_to_schedule_operation(op, has_executed, graph):
A boolean indicating whether the operation may be scheduled for
execution based on what has already been executed.
"""
# unordered, not iterated
dependencies = set(filter(lambda v: isinstance(v, Operation),
nx.ancestors(graph, op)))
return dependencies.issubset(has_executed)
Expand Down
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
author_email='[email protected]',
url='http://github.com/yahoo/graphkit',
packages=['graphkit'],
install_requires=['networkx'],
install_requires=[
"networkx; python_version >= '3.5'",
"networkx == 2.2; python_version < '3.5'",
"boltons" # for IndexSet
],
extras_require={
'plot': ['pydot', 'matplotlib']
},
Expand Down
34 changes: 33 additions & 1 deletion test/test_graphkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pickle

from pprint import pprint
from operator import add
from operator import add, sub, floordiv, mul
from numpy.testing import assert_raises

import graphkit.network as network
Expand Down Expand Up @@ -184,6 +184,38 @@ def test_pruning_raises_for_bad_output():
outputs=['sum1', 'sum3', 'sum4'])


def test_unsatisfied_operations():
    """Check that operations with partial inputs are culled and do not fail."""
    add_op = operation(name="add", needs=["a", "b1"], provides=["a+b1"])(add)
    sub_op = operation(name="sub", needs=["a", "b2"], provides=["a-b2"])(sub)
    pipeline = compose(name="graph")(add_op, sub_op)

    # Only "b1" is given: the expected results contain just the "add" output.
    given = {"a": 10, "b1": 2}
    assert pipeline(dict(given)) == {"a": 10, "b1": 2, "a+b1": 12}
    assert pipeline(dict(given), outputs=["a+b1"]) == {"a+b1": 12}

    # Only "b2" is given: the expected results contain just the "sub" output.
    given = {"a": 10, "b2": 2}
    assert pipeline(dict(given)) == {"a": 10, "b2": 2, "a-b2": 8}
    assert pipeline(dict(given), outputs=["a-b2"]) == {"a-b2": 8}

def test_unsatisfied_operations_same_out():
    """Check unsatisfied pairs of operations providing the same output."""
    mul_op = operation(name="mul", needs=["a", "b1"], provides=["ab"])(mul)
    div_op = operation(name="div", needs=["a", "b2"], provides=["ab"])(floordiv)
    add_op = operation(name="add", needs=["ab", "c"], provides=["ab_plus_c"])(add)
    pipeline = compose(name="graph")(mul_op, div_op, add_op)

    # With "b1" given, the expected "ab" is the product 10 * 2.
    given = {"a": 10, "b1": 2, "c": 1}
    expected = {"a": 10, "b1": 2, "c": 1, "ab": 20, "ab_plus_c": 21}
    assert pipeline(dict(given)) == expected
    assert pipeline(dict(given), outputs=["ab_plus_c"]) == {"ab_plus_c": 21}

    # With "b2" given, the expected "ab" is the floor division 10 // 2.
    given = {"a": 10, "b2": 2, "c": 1}
    expected = {"a": 10, "b2": 2, "c": 1, "ab": 5, "ab_plus_c": 6}
    assert pipeline(dict(given)) == expected
    assert pipeline(dict(given), outputs=["ab_plus_c"]) == {"ab_plus_c": 6}


def test_optional():
# Test that optional() needs work as expected.

Expand Down