Skip to content

Commit

Permalink
ENH(DAG): NEW SOLVER
Browse files Browse the repository at this point in the history
+ Pruning behaves correctly also when outputs given;
  this happens by breaking incoming provide-links
  to any given intermedediate inputs.
+ Unsatisfied detection now includes those without outputs
  due to broken links (above).
+ Remove some uneeded "glue" from unsatisfied-detection code,
  leftover from previous compile() refactoring.
+ Renamed satisfiable --> satisfied.
+ Improved unknown output requested raise-message.
+ x3 TCs PASS, x1 in #24 and the first x2 in #25.
- 1x TCs in #25 still FAIL, and need "Pinning" of given-inputs
  (the operation MUST and MUST NOT run in these cases).
  • Loading branch information
ankostis committed Oct 3, 2019
1 parent 17eb2fd commit 0830b7c
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 74 deletions.
156 changes: 84 additions & 72 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ def show_layers(self, debug=False, ret=False):
def _build_execution_plan(self, dag):
"""
Create the list of operation-nodes & *instructions* evaluating all
operations & instructions needed a) to free memory and b) avoid
overwritting given intermediate inputs.
:param dag:
as shrinked by :meth:`compile()`
the original dag but "shrinked", not "broken"
In the list :class:`DeleteInstructions` steps (DA) are inserted between
operation nodes to reduce the memory footprint of cached results.
Expand Down Expand Up @@ -187,45 +187,57 @@ def _build_execution_plan(self, dag):

return plan

def _collect_unsatisfiable_operations(self, necessary_nodes, inputs):
def _collect_unsatisfied_operations(self, dag, inputs):
"""
Traverse ordered graph and mark satisfied needs on each operation,
Traverse topologically sorted dag to collect un-satisfied operations.
Unsatisfied operations are those suffering from ANY of the following:
collecting those missing at least one.
Since the graph is ordered, as soon as we're on an operation,
all its needs have been accounted, so we can get its satisfaction.
- They are missing at least one compulsory need-input.
Since the dag is ordered, as soon as we're on an operation,
all its needs have been accounted, so we can get its satisfaction.
:param necessary_nodes:
the subset of the graph to consider but WITHOUT the initial data
(because that is what :meth:`compile()` can gives us...)
- Their provided outputs are not linked to any data in the dag.
An operation might not have any output link when :meth:`_solve_dag()`
has broken them, due to given intermediate inputs.
:param dag:
the graph to consider
:param inputs:
an iterable of the names of the input values
return:
a list of unsatisfiable operations
a list of unsatisfied operations to prune
"""
G = self.graph # shortcut
ok_data = set(inputs) # to collect producible data
op_satisfaction = defaultdict(set) # to collect operation satisfiable needs
unsatisfiables = [] # to collect operations with partial needs
# We also need inputs to mark op_satisfaction.
nodes = chain(necessary_nodes, inputs) # note that `inputs` are plain strings
for node in nx.topological_sort(G.subgraph(nodes)):
# To collect data that will be produced.
ok_data = set(inputs)
# To colect the map of operations --> satisfied-needs.
op_satisfaction = defaultdict(set)
# To collect the operations to drop.
unsatisfied = []
for node in nx.topological_sort(dag):
if isinstance(node, Operation):
real_needs = set(n for n in node.needs if not isinstance(n, optional))
if real_needs.issubset(op_satisfaction[node]):
# mark all future data-provides as ok
ok_data.update(G.adj[node])
if not dag.adj[node]:
# Prune operations that ended up providing no output.
unsatisfied.append(node)
else:
unsatisfiables.append(node)
real_needs = set(n for n in node.needs
if not isinstance(n, optional))
if real_needs.issubset(op_satisfaction[node]):
# We have a satisfied operation; mark its output-data
# as ok.
ok_data.update(dag.adj[node])
else:
# Prune operations with partial inputs.
unsatisfied.append(node)
elif isinstance(node, (DataPlaceholderNode, str)): # `str` are givens
if node in ok_data:
# mark satisfied-needs on all future operations
for future_op in G.adj[node]:
for future_op in dag.adj[node]:
op_satisfaction[future_op].add(node)
else:
raise AssertionError("Unrecognized network graph node %r" % node)

return unsatisfiables
return unsatisfied


def _solve_dag(self, outputs, inputs):
Expand All @@ -245,68 +257,64 @@ def _solve_dag(self, outputs, inputs):
The inputs names of all given inputs.
:return:
the subgraph comprising the solution
the *execution plan*
"""
graph = self.graph
if not outputs:
dag = self.graph

# If caller requested all outputs, the necessary nodes are all
# nodes that are reachable from one of the inputs. Ignore input
# names that aren't in the graph.
necessary_nodes = set() # unordered, not iterated
for input_name in iter(inputs):
if graph.has_node(input_name):
necessary_nodes |= nx.descendants(graph, input_name)
# Ignore input names that aren't in the graph.
graph_inputs = iset(dag.nodes) & inputs # preserve order

else:
# Scream if some requested outputs aren't in the graph.
unknown_outputs = iset(outputs) - dag.nodes
if unknown_outputs:
raise ValueError(
"Unknown output node(s) requested: %s"
% ", ".join(unknown_outputs))

broken_dag = dag.copy() # preserve net's graph

# If the caller requested a subset of outputs, find any nodes that
# are made unecessary because we were provided with an input that's
# deeper into the network graph. Ignore input names that aren't
# in the graph.
unnecessary_nodes = set() # unordered, not iterated
for input_name in iter(inputs):
if graph.has_node(input_name):
unnecessary_nodes |= nx.ancestors(graph, input_name)

# Find the nodes we need to be able to compute the requested
# outputs. Raise an exception if a requested output doesn't
# exist in the graph.
necessary_nodes = set() # unordered, not iterated
for output_name in outputs:
if not graph.has_node(output_name):
raise ValueError("graphkit graph does not have an output "
"node named %s" % output_name)
necessary_nodes |= nx.ancestors(graph, output_name)

# Get rid of the unnecessary nodes from the set of necessary ones.
necessary_nodes -= unnecessary_nodes

# Drop (un-satifiable) operations with partial inputs.
# Break the incoming edges to all given inputs.
#
# Nodes producing any given intermediate inputs are unecessary
# (unless they are also used elsewhere).
# To discover which ones to prune, we break their incoming edges
# and they will drop out while collecting ancestors from the outputs.
for given in graph_inputs:
broken_dag.remove_edges_from(list(broken_dag.in_edges(given)))

if outputs:
# If caller requested specific outputs, we can prune any
# unrelated nodes further up the dag.
ending_in_outputs = set()
for input_name in outputs:
ending_in_outputs.update(nx.ancestors(dag, input_name))
broken_dag = broken_dag.subgraph(ending_in_outputs | set(outputs))


# Prune (un-satifiable) operations with partial inputs.
# See yahoo/graphkit#18
#
unsatisfiables = self._collect_unsatisfiable_operations(necessary_nodes, inputs)
necessary_nodes -= set(unsatisfiables)
unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs)
shrinked_dag = dag.subgraph(broken_dag.nodes - unsatisfied)

shrinked_dag = graph.subgraph(necessary_nodes)
plan = self._build_execution_plan(shrinked_dag)

return shrinked_dag
return plan


def compile(self, outputs=(), inputs=()):
"""
Solve dag, set the :attr:`execution_plan` and cache it.
Solve dag, set the :attr:`execution_plan`, and cache it.
See :meth:`_solve_dag()` for description
See :meth:`_solve_dag()` for detailed description.
:param iterable outputs:
A list of desired output names. This can also be ``None``, in which
case the necessary steps are all graph nodes that are reachable
from one of the provided inputs.
:param dict inputs:
The inputs names of all given inputs.
The input names of all given inputs.
"""

# return steps if it has already been computed before for this set of inputs and outputs
Expand All @@ -317,8 +325,7 @@ def compile(self, outputs=(), inputs=()):
if cache_key in self._cached_execution_plans:
self.execution_plan = self._cached_execution_plans[cache_key]
else:
dag = self._solve_dag(outputs, inputs)
plan = self._build_execution_plan(dag)
plan = self._solve_dag(outputs, inputs)
# save this result in a precomputed cache for future lookup
self.execution_plan = self._cached_execution_plans[cache_key] = plan

Expand All @@ -338,16 +345,21 @@ def compute(self, outputs, named_inputs, method=None):
and the values are the concrete values you
want to set for the data node.
:param method:
if ``"parallel"``, launches multi-threading.
Set when invoking a composed graph or by
:meth:`~NetworkOperation.set_execution_method()`.
:returns: a dictionary of output data objects, keyed by name.
"""

assert isinstance(outputs, (list, tuple)) or outputs is None,\
"The outputs argument must be a list"

# start with fresh data cache
cache = {}
cache.update(named_inputs)
# start with fresh data cache & overwrites
cache = named_inputs.copy()

# Build and set :attr:`execution_plan`.
self.compile(outputs, named_inputs.keys())

# choose a method of execution
Expand Down
5 changes: 3 additions & 2 deletions test/test_graphkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_pruning_multiouts_not_override_intermediates2():
operation(name="op2", needs=["d", "e"], provides=["asked"])(mul),
)

exp = {"a": 5, "overriden": 1, "c": 2, "asked": 3}
exp = {"a": 5, "overriden": 1, "c": 2, "d": 3, "e": 10, "asked": 30}
# FAILs
# - on v1.2.4 with (overriden, asked) = (5, 70) instead of (1, 13)
# - on #18(unsatisfied) + #23(ordered-sets) like v1.2.4.
Expand All @@ -265,12 +265,13 @@ def test_pruning_with_given_intermediate_and_asked_out():
operation(name="good_op", needs=["a", "given-2"], provides=["asked"])(add),
)

exp = {"given-1": 5, "b": 2, "given-2": 7, "a": 5, "asked": 12}
exp = {"given-1": 5, "b": 2, "given-2": 2, "a": 5, "asked": 7}
# v1.2.4 is ok
assert netop({"given-1": 5, "b": 2, "given-2": 2}) == exp
# FAILS
# - on v1.2.4 with KeyError: 'a',
# - on #18 (unsatisfied) with no result.
# FIXED on #18+#26 (new dag solver).
assert netop({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")


Expand Down

0 comments on commit 0830b7c

Please sign in to comment.