Skip to content

Commit

Permalink
merge UNASTIFIABLE + ORDERED_SETs
Browse files Browse the repository at this point in the history
  • Loading branch information
ankostis committed Oct 1, 2019
2 parents bc4c221 + 12bdfe4 commit b8377ca
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
13 changes: 8 additions & 5 deletions graphkit/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from itertools import chain

from boltons.setutils import IndexedSet as iset

from .base import Operation, NetworkOperation
from .network import Network
from .modifiers import optional
Expand All @@ -28,7 +30,7 @@ def _compute(self, named_inputs, outputs=None):

result = zip(self.provides, result)
if outputs:
outputs = set(outputs)
outputs = sorted(set(outputs))
result = filter(lambda x: x[0] in outputs, result)

return dict(result)
Expand Down Expand Up @@ -185,22 +187,23 @@ def __call__(self, *operations):

# If merge is desired, deduplicate operations before building network
if self.merge:
merge_set = set()
merge_set = iset() # Preserve given node order.
for op in operations:
if isinstance(op, NetworkOperation):
net_ops = filter(lambda x: isinstance(x, Operation), op.net.steps)
merge_set.update(net_ops)
else:
merge_set.add(op)
operations = list(merge_set)
operations = merge_set

def order_preserving_uniquifier(seq, seen=None):
seen = seen if seen else set()
seen = seen if seen else set() # unordered, not iterated
seen_add = seen.add
return [x for x in seq if not (x in seen or seen_add(x))]

provides = order_preserving_uniquifier(chain(*[op.provides for op in operations]))
needs = order_preserving_uniquifier(chain(*[op.needs for op in operations]), set(provides))
needs = order_preserving_uniquifier(chain(*[op.needs for op in operations]),
set(provides)) # unordered, not iterated

# compile network
net = Network()
Expand Down
17 changes: 10 additions & 7 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from io import StringIO

from boltons.setutils import IndexedSet as iset

from .base import Operation
from .modifiers import optional

Expand Down Expand Up @@ -108,7 +110,7 @@ def compile(self):
self.steps = []

# create an execution order such that each layer's needs are provided.
ordered_nodes = list(nx.dag.topological_sort(self.graph))
ordered_nodes = iset(nx.topological_sort(self.graph))

# add Operations evaluation steps, and instructions to free data.
for i, node in enumerate(ordered_nodes):
Expand Down Expand Up @@ -192,7 +194,7 @@ def is_need_satisfiable(need):


def _collect_satisfiable_operations(self, nodes, inputs):
satisfiables = set()
satisfiables = set() # unordered, not iterated
visited = {}
for node in nodes:
if node not in visited and isinstance(node, Operation):
Expand Down Expand Up @@ -223,7 +225,7 @@ def _find_necessary_steps(self, outputs, inputs):
"""

# return steps if it has already been computed before for this set of inputs and outputs
outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set)) else outputs
outputs = tuple(sorted(outputs)) if isinstance(outputs, (list, set, iset)) else outputs
inputs_keys = tuple(sorted(inputs.keys()))
cache_key = (inputs_keys, outputs)
if cache_key in self._necessary_steps_cache:
Expand All @@ -235,7 +237,7 @@ def _find_necessary_steps(self, outputs, inputs):
# If caller requested all outputs, the necessary nodes are all
# nodes that are reachable from one of the inputs. Ignore input
# names that aren't in the graph.
necessary_nodes = set()
necessary_nodes = set() # unordered, not iterated
for input_name in iter(inputs):
if graph.has_node(input_name):
necessary_nodes |= nx.descendants(graph, input_name)
Expand All @@ -246,15 +248,15 @@ def _find_necessary_steps(self, outputs, inputs):
# are made unnecessary because we were provided with an input that's
# deeper into the network graph. Ignore input names that aren't
# in the graph.
unnecessary_nodes = set()
unnecessary_nodes = set() # unordered, not iterated
for input_name in iter(inputs):
if graph.has_node(input_name):
unnecessary_nodes |= nx.ancestors(graph, input_name)

# Find the nodes we need to be able to compute the requested
# outputs. Raise an exception if a requested output doesn't
# exist in the graph.
necessary_nodes = set()
necessary_nodes = set() # unordered, not iterated
for output_name in outputs:
if not graph.has_node(output_name):
raise ValueError("graphkit graph does not have an output "
Expand Down Expand Up @@ -333,7 +335,7 @@ def _compute_thread_pool_barrier_method(self, named_inputs, outputs,
necessary_nodes = self._find_necessary_steps(outputs, named_inputs)

# this keeps track of all nodes that have already executed
has_executed = set()
has_executed = set() # unordered, not iterated

# with each loop iteration, we determine a set of operations that can be
# scheduled, then schedule them onto a thread pool, then collect their
Expand Down Expand Up @@ -531,6 +533,7 @@ def ready_to_schedule_operation(op, has_executed, graph):
A boolean indicating whether the operation may be scheduled for
execution based on what has already been executed.
"""
# unordered, not iterated
dependencies = set(filter(lambda v: isinstance(v, Operation),
nx.ancestors(graph, op)))
return dependencies.issubset(has_executed)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
install_requires=[
"networkx; python_version >= '3.5'",
"networkx == 2.2; python_version < '3.5'",
"boltons" # for IndexedSet
],
extras_require={
'plot': ['pydot', 'matplotlib']
Expand Down

0 comments on commit b8377ca

Please sign in to comment.