-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
432 lines (368 loc) · 13.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
import collections
import itertools
import networkx as nx
from graph_nets import graphs
from graph_nets import utils_np
from graph_nets import utils_tf
from graph_nets.demos import models
import numpy as np
import tensorflow as tf
# Name of the edge attribute that `add_shortest_path` passes as `weight=` to
# `nx.shortest_path` when computing the sampled path.
DISTANCE_WEIGHT_NAME = "weight"
def dict_to_graph(d):
  """Builds an undirected `nx.Graph` from a graph data dict.

  Args:
    d: A mapping with "nodes", "edges", "senders" and "receivers" entries.
      Every element of `d['nodes']` becomes a node. Edge number `k` connects
      `d['senders'][k]` to `d['receivers'][k]`, and the first element of
      `d['edges'][k]` is stored as that edge's "weight" attribute.

  Returns:
    The populated `nx.Graph`.
  """
  result = nx.Graph()
  result.add_nodes_from(d['nodes'])
  for position, edge_features in enumerate(d['edges']):
    sender, receiver = d['senders'][position], d['receivers'][position]
    result.add_edge(sender, receiver, weight=edge_features[0])
  return result
def pairwise(iterable):
  """Pairs every element of `iterable` with its successor.

  For example, `[s0, s1, s2, s3]` gives `(s0, s1), (s1, s2), (s2, s3)`.

  Args:
    iterable: Any iterable.

  Returns:
    The `zip` of the iterable with a copy of itself shifted by one element.
  """
  leading, trailing = itertools.tee(iterable, 2)
  # Drop the first element of the trailing copy so the two copies are offset
  # by one; the `None` default keeps this safe for an empty iterable.
  next(trailing, None)
  return zip(leading, trailing)
def set_diff(seq0, seq1):
  """Returns the distinct elements of `seq0` that do not occur in `seq1`.

  Both arguments are converted to sets, so duplicates are dropped and the order
  of the returned list is unspecified.

  Args:
    seq0: Iterable of hashable elements to keep from.
    seq1: Iterable of hashable elements to exclude.

  Returns:
    A `list` with the set difference `seq0 - seq1`.
  """
  kept = set(seq0)
  excluded = set(seq1)
  return list(kept.difference(excluded))
def to_one_hot(indices, max_value, axis=-1):
  """One-hot encodes `indices` by selecting rows of an identity matrix.

  Args:
    indices: Integer index (or array-like of integer indices) used to index the
      rows of a `max_value` x `max_value` identity matrix.
    max_value: Size of the one-hot dimension.
    axis: Position of the one-hot dimension in the result. `-1` and the
      result's `ndim` both leave it as the last axis.

  Returns:
    A float `np.ndarray` with one extra dimension compared to `indices`.
  """
  encoded = np.eye(max_value)[indices]
  # The one-hot dimension is created last; only move it when asked to.
  if axis == -1 or axis == encoded.ndim:
    return encoded
  return np.moveaxis(encoded, -1, axis)
def get_node_dict(graph, attr):
  """Returns a dict mapping every node of `graph` to its `attr` attribute.

  Args:
    graph: A networkx graph.
    attr: Name of the node attribute to extract.

  Returns:
    A `dict` of `{node: graph node attribute value}`.

  Raises:
    KeyError: If a node does not carry the `attr` attribute.
  """
  # Iterate `nodes(data=True)` rather than the `graph.node` mapping: the latter
  # was removed in networkx 2.4, while `nodes(data=True)` yields
  # `(node, attribute_dict)` pairs on both old and new networkx versions.
  return {node: data[attr] for node, data in graph.nodes(data=True)}
def add_shortest_path(rand, graph, min_length=1):
  """Samples a shortest path from A to B and adds attributes to indicate it.

  A (start, end) node pair is drawn so that every distinct shortest-path length
  (>= `min_length`) is equally likely, and pairs sharing a length are equally
  likely among themselves. The path between them is then marked on a directed
  copy of the graph through boolean "start", "end" and "solution" attributes.

  Args:
    rand: A random number generator exposing `choice(n, p=...)` (for example a
      `np.random.RandomState`).
    graph: A `nx.Graph`.
    min_length: (optional) An `int` minimum number of edges in the shortest
      path. Default= 1.

  Returns:
    The `nx.DiGraph` with the shortest path added.

  Raises:
    ValueError: All shortest paths are below the minimum length
  """
  # Older networkx returns a dict of dicts here, newer networkx an iterator of
  # `(node, dict)` pairs; `dict()` accepts both forms.
  lengths = dict(nx.all_pairs_shortest_path_length(graph))

  # Map from node pairs to the length of their shortest path, keeping only the
  # pairs that are at least `min_length` apart.
  pair_to_length_dict = {}
  for x, yy in lengths.items():
    for y, l in yy.items():
      if l >= min_length:
        pair_to_length_dict[x, y] = l
  # The dict is already filtered, so "no path is long enough" is exactly the
  # empty case (checking `max(...)` would fail with an unrelated message).
  if not pair_to_length_dict:
    raise ValueError("All shortest paths are below the minimum length")

  # The node pairs which meet or exceed the minimum length.
  node_pairs = list(pair_to_length_dict)

  # Computes probabilities per pair, to enforce uniform sampling of each
  # shortest path lengths.
  # The counts of pairs per length.
  counts = collections.Counter(pair_to_length_dict.values())
  prob_per_length = 1.0 / len(counts)
  probabilities = [
      prob_per_length / counts[pair_to_length_dict[x]] for x in node_pairs
  ]

  # Choose the start and end points.
  i = rand.choice(len(node_pairs), p=probabilities)
  start, end = node_pairs[i]
  # NOTE(review): the lengths above are computed without a `weight` argument,
  # while this path uses the DISTANCE_WEIGHT_NAME edge attribute, so the chosen
  # path may have a different edge count than the sampled length - confirm
  # this is intended.
  path = nx.shortest_path(
      graph, source=start, target=end, weight=DISTANCE_WEIGHT_NAME)

  # Creates a directed graph, to store the directed path from start to end.
  digraph = graph.to_directed()

  # Add the "start", "end", and "solution" attributes to the nodes and edges.
  digraph.add_node(start, start=True)
  digraph.add_node(end, end=True)
  digraph.add_nodes_from(set_diff(digraph.nodes(), [start]), start=False)
  digraph.add_nodes_from(set_diff(digraph.nodes(), [end]), end=False)
  digraph.add_nodes_from(set_diff(digraph.nodes(), path), solution=False)
  digraph.add_nodes_from(path, solution=True)
  path_edges = list(pairwise(path))
  digraph.add_edges_from(set_diff(digraph.edges(), path_edges), solution=False)
  digraph.add_edges_from(path_edges, solution=True)

  return digraph
def graph_to_input_target(graph):
  """Returns 2 graphs with input and target feature vectors for training.

  Both returned graphs are copies of `graph` whose nodes, edges and graph-level
  dict carry an extra "features" entry:
    * input nodes: concatenation of the "pos", "start" and "end" attributes;
    * input edges: the "distance" attribute;
    * target nodes/edges: one-hot (2 classes) encoding of "solution";
    * input global: `[0.0]`; target global: the fraction of nodes that are on
      the solution path.

  Args:
    graph: An `nx.DiGraph` instance.

  Returns:
    The input `nx.DiGraph` instance.
    The target `nx.DiGraph` instance.

  Raises:
    KeyError: If a node or edge lacks one of the attributes listed above.
  """

  def create_feature(attr, fields):
    """Concatenates the `fields` entries of `attr` into one float vector."""
    return np.hstack([np.array(attr[field], dtype=float) for field in fields])

  input_node_fields = ("pos", "start", "end")
  # NOTE(review): edges are read from "distance" while DISTANCE_WEIGHT_NAME is
  # "weight" - confirm the graphs passed here carry a "distance" attribute.
  input_edge_fields = ("distance",)
  target_node_fields = ("solution",)
  target_edge_fields = ("solution",)

  input_graph = graph.copy()
  target_graph = graph.copy()

  solution_length = 0
  for node_index, node_feature in graph.nodes(data=True):
    input_graph.add_node(
        node_index, features=create_feature(node_feature, input_node_fields))
    # `create_feature` yields a length-1 vector, so the one-hot result has a
    # leading dimension of 1 that `[0]` strips.
    target_node = to_one_hot(
        create_feature(node_feature, target_node_fields).astype(int), 2)[0]
    target_graph.add_node(node_index, features=target_node)
    solution_length += int(node_feature["solution"])
  # `float(...)` keeps this a true division on Python 2 as well.
  solution_length /= float(graph.number_of_nodes())

  # `edges(data=True)` yields `(source, target, attributes)`, i.e. the sender
  # comes first. The features must be written back onto that same directed
  # edge; writing them onto the reversed pair would attach each edge's
  # "solution" label to the opposite direction of the path.
  for sender, receiver, features in graph.edges(data=True):
    input_graph.add_edge(
        sender, receiver, features=create_feature(features, input_edge_fields))
    target_edge = to_one_hot(
        create_feature(features, target_edge_fields).astype(int), 2)[0]
    target_graph.add_edge(sender, receiver, features=target_edge)

  input_graph.graph["features"] = np.array([0.0])
  target_graph.graph["features"] = np.array([solution_length], dtype=float)

  return input_graph, target_graph
def create_placeholders(input_graph, target_graph):
  """Creates placeholders for one input graph and one target graph.

  Each graph is wrapped in a single-element list and handed to
  `utils_tf.placeholders_from_networkxs`.

  Args:
    input_graph: The networkx graph used as the model input example.
    target_graph: The networkx graph used as the training target example.

  Returns:
    A `(input_placeholder, target_placeholder)` pair.
  """
  return (
      utils_tf.placeholders_from_networkxs([input_graph]),
      utils_tf.placeholders_from_networkxs([target_graph]),
  )
def compute_accuracy(target, output, use_nodes=True, use_edges=False):
  """Calculate model accuracy.

  Returns the fraction of correctly predicted node/edge labels and the fraction
  of completely solved graphs (100% correct predictions). A label is the
  argmax over the last axis of the corresponding feature array.

  Args:
    target: A `graphs.GraphsTuple` that contains the target graph.
    output: A `graphs.GraphsTuple` that contains the output graph.
    use_nodes: A `bool` indicator of whether to compute node accuracy or not.
    use_edges: A `bool` indicator of whether to compute edge accuracy or not.

  Returns:
    correct: A `float` fraction of correctly labeled nodes/edges.
    solved: A `float` fraction of graphs that are completely correctly labeled.

  Raises:
    ValueError: Nodes or edges (or both) must be used
  """
  if not use_nodes and not use_edges:
    raise ValueError("Nodes or edges (or both) must be used")
  # Only the requested kinds of features are read, so an unused field is never
  # touched (and cannot make the computation fail).
  fields = []
  if use_nodes:
    fields.append("nodes")
  if use_edges:
    fields.append("edges")

  tdds = utils_np.graphs_tuple_to_data_dicts(target)
  odds = utils_np.graphs_tuple_to_data_dicts(output)
  cs = []  # Per-graph boolean arrays: is each label predicted correctly?
  ss = []  # Per-graph booleans: is every label of the graph correct?
  for td, od in zip(tdds, odds):
    c = np.concatenate(
        [
            np.argmax(td[field], axis=-1) == np.argmax(od[field], axis=-1)
            for field in fields
        ],
        axis=0)
    cs.append(c)
    ss.append(np.all(c))
  correct = np.mean(np.concatenate(cs, axis=0))
  solved = np.mean(np.stack(ss))
  return correct, solved
def create_loss_ops(target_op, output_ops):
  """Builds one softmax cross-entropy loss op per output graph.

  Args:
    target_op: The target graph; its `nodes` and `edges` are used as labels.
    output_ops: An iterable of output graphs; their `nodes` and `edges` are
      used as logits.

  Returns:
    A `list` with, for every element of `output_ops`, the sum of its node loss
    and its edge loss.
  """

  def _node_plus_edge_loss(output_op):
    node_loss = tf.losses.softmax_cross_entropy(
        target_op.nodes, output_op.nodes)
    edge_loss = tf.losses.softmax_cross_entropy(
        target_op.edges, output_op.edges)
    return node_loss + edge_loss

  return [_node_plus_edge_loss(output_op) for output_op in output_ops]
def make_all_runnable_in_session(*args):
  """Lets an iterable of TF graphs be output from a session as NP graphs.

  Args:
    *args: The graphs to convert, each passed through
      `utils_tf.make_runnable_in_session`.

  Returns:
    A `list` with the converted graphs, in the order they were given.
  """
  runnable = []
  for graph in args:
    runnable.append(utils_tf.make_runnable_in_session(graph))
  return runnable
class GraphPlotter(object):
  """Draws a networkx graph, optionally highlighting a marked solution path.

  The node and edge subsets exposed as properties are derived from the boolean
  "start", "end" and "solution" attributes stored on the graph's nodes and
  edges (a missing attribute counts as `False`). Every subset is computed on
  first access and cached afterwards.
  """

  def __init__(self, ax, graph, pos):
    """Stores the drawing context and hides the axes decorations.

    Args:
      ax: The axes to draw on (presumably a matplotlib `Axes`; only
        `set_axis_off` and `set_title` are called on it directly).
      graph: The networkx graph to draw.
      pos: Node positions, forwarded as `pos` to the networkx draw functions.
    """
    self._ax = ax
    self._graph = graph
    self._pos = pos
    # Arguments shared by every networkx draw call.
    self._base_draw_kwargs = dict(G=self._graph, pos=self._pos, ax=self._ax)
    # Lazily computed caches, filled by the properties below.
    self._solution_length = None
    self._nodes = None
    self._edges = None
    self._start_nodes = None
    self._end_nodes = None
    self._solution_nodes = None
    self._intermediate_solution_nodes = None
    self._solution_edges = None
    self._non_solution_nodes = None
    self._non_solution_edges = None
    self._ax.set_axis_off()

  def _nodes_where(self, predicate):
    """Returns the nodes whose attribute dict satisfies `predicate`."""
    # `nodes(data=True)` yields `(node, attribute_dict)` pairs on old and new
    # networkx alike, unlike the `graph.node` mapping removed in networkx 2.4.
    return [n for n, attrs in self._graph.nodes(data=True) if predicate(attrs)]

  @property
  def solution_length(self):
    """Number of edges marked as part of the solution."""
    if self._solution_length is None:
      # Use the property rather than the `_solution_edges` cache, which is
      # still `None` if the solution edges have not been requested yet.
      self._solution_length = len(self.solution_edges)
    return self._solution_length

  @property
  def nodes(self):
    """All nodes of the graph."""
    if self._nodes is None:
      self._nodes = self._graph.nodes()
    return self._nodes

  @property
  def edges(self):
    """All edges of the graph."""
    if self._edges is None:
      self._edges = self._graph.edges()
    return self._edges

  @property
  def start_nodes(self):
    """Nodes whose "start" attribute is truthy."""
    if self._start_nodes is None:
      self._start_nodes = self._nodes_where(
          lambda attrs: attrs.get("start", False))
    return self._start_nodes

  @property
  def end_nodes(self):
    """Nodes whose "end" attribute is truthy."""
    if self._end_nodes is None:
      self._end_nodes = self._nodes_where(
          lambda attrs: attrs.get("end", False))
    return self._end_nodes

  @property
  def solution_nodes(self):
    """Nodes whose "solution" attribute is truthy."""
    if self._solution_nodes is None:
      self._solution_nodes = self._nodes_where(
          lambda attrs: attrs.get("solution", False))
    return self._solution_nodes

  @property
  def intermediate_solution_nodes(self):
    """Solution nodes that are neither a start node nor an end node."""
    if self._intermediate_solution_nodes is None:
      self._intermediate_solution_nodes = self._nodes_where(
          lambda attrs: (attrs.get("solution", False) and
                         not attrs.get("start", False) and
                         not attrs.get("end", False)))
    return self._intermediate_solution_nodes

  @property
  def solution_edges(self):
    """Edges whose "solution" attribute is truthy."""
    if self._solution_edges is None:
      self._solution_edges = [
          e for e in self.edges
          if self._graph.get_edge_data(e[0], e[1]).get("solution", False)
      ]
    return self._solution_edges

  @property
  def non_solution_nodes(self):
    """Nodes whose "solution" attribute is missing or falsy."""
    if self._non_solution_nodes is None:
      self._non_solution_nodes = self._nodes_where(
          lambda attrs: not attrs.get("solution", False))
    return self._non_solution_nodes

  @property
  def non_solution_edges(self):
    """Edges whose "solution" attribute is missing or falsy."""
    if self._non_solution_edges is None:
      self._non_solution_edges = [
          e for e in self.edges
          if not self._graph.get_edge_data(e[0], e[1]).get("solution", False)
      ]
    return self._non_solution_edges

  def _make_draw_kwargs(self, **kwargs):
    """Merges `kwargs` with the shared draw arguments (the latter win)."""
    kwargs.update(self._base_draw_kwargs)
    return kwargs

  def _draw(self, draw_function, zorder=None, **kwargs):
    """Calls `draw_function` and applies `zorder` to whatever it returns.

    Args:
      draw_function: A callable accepting the merged draw keyword arguments.
      zorder: (optional) z-order to set on the returned artist(s).
      **kwargs: Extra keyword arguments for `draw_function`.

    Returns:
      The value returned by `draw_function`.
    """
    draw_kwargs = self._make_draw_kwargs(**kwargs)
    collection = draw_function(**draw_kwargs)
    if collection is not None and zorder is not None:
      if hasattr(collection, "set_zorder"):
        # A single artist/collection was returned.
        collection.set_zorder(zorder)
      else:
        # A (possibly empty) sequence of artists was returned, e.g. one patch
        # per drawn edge; every one of them needs the z-order.
        for artist in collection:
          artist.set_zorder(zorder)
    return collection

  def draw_nodes(self, **kwargs):
    """Useful kwargs: nodelist, node_size, node_color, linewidths."""
    # `collections.Sequence` was removed in Python 3.10; the ABC lives in
    # `collections.abc`. Fall back for interpreters without that module.
    try:
      from collections.abc import Sequence
    except ImportError:
      from collections import Sequence
    node_color = kwargs.get("node_color")
    # A single RGB/RGBA colour (3 or 4 scalars) is expanded to one row per
    # node so that it is passed on as an explicit per-node colour array.
    if (isinstance(node_color, Sequence) and
        len(node_color) in {3, 4} and
        not isinstance(node_color[0], (Sequence, np.ndarray))):
      num_nodes = len(kwargs.get("nodelist", self.nodes))
      kwargs["node_color"] = np.tile(
          np.array(node_color)[None], [num_nodes, 1])
    return self._draw(nx.draw_networkx_nodes, **kwargs)

  def draw_edges(self, **kwargs):
    """Useful kwargs: edgelist, width."""
    return self._draw(nx.draw_networkx_edges, **kwargs)

  def draw_graph(self,
                 node_size=200,
                 node_color=(0.4, 0.8, 0.4),
                 node_linewidth=1.0,
                 edge_width=1.0):
    """Draws every node and edge with one uniform style.

    Args:
      node_size: Size of the drawn nodes.
      node_color: Colour of the drawn nodes.
      node_linewidth: Line width of the node borders.
      edge_width: Line width of the edges.
    """
    # Plot nodes.
    self.draw_nodes(
        nodelist=self.nodes,
        node_size=node_size,
        node_color=node_color,
        linewidths=node_linewidth,
        zorder=20)
    # Plot edges.
    self.draw_edges(edgelist=self.edges, width=edge_width, zorder=10)

  def draw_graph_with_solution(self,
                               node_size=200,
                               node_color=(0.4, 0.8, 0.4),
                               node_linewidth=1.0,
                               edge_width=1.0,
                               start_color="w",
                               end_color="k",
                               solution_node_linewidth=3.0,
                               solution_edge_width=3.0):
    """Draws the graph with the solution path emphasised.

    Args:
      node_size: Size of the drawn nodes.
      node_color: Colour for intermediate-solution and non-solution nodes;
        either one colour or a `dict` mapping each node to its colour.
      node_linewidth: Border width of non-solution nodes.
      edge_width: Line width of non-solution edges.
      start_color: Colour of the start nodes.
      end_color: Colour of the end nodes.
      solution_node_linewidth: Border width of start/end/solution nodes.
      solution_edge_width: Line width of solution edges.

    Returns:
      A `dict` mapping a description of each drawn group (nodes and edges) to
      the value returned by the corresponding draw call.
    """

    def colors_for(nodes):
      # A dict gives one colour per node; anything else is used as is.
      if isinstance(node_color, dict):
        return [node_color[n] for n in nodes]
      return node_color

    node_border_color = (0.0, 0.0, 0.0, 1.0)
    node_collections = {}
    # Plot start nodes.
    node_collections["start nodes"] = self.draw_nodes(
        nodelist=self.start_nodes,
        node_size=node_size,
        node_color=start_color,
        linewidths=solution_node_linewidth,
        edgecolors=node_border_color,
        zorder=100)
    # Plot end nodes.
    node_collections["end nodes"] = self.draw_nodes(
        nodelist=self.end_nodes,
        node_size=node_size,
        node_color=end_color,
        linewidths=solution_node_linewidth,
        edgecolors=node_border_color,
        zorder=90)
    # Plot intermediate solution nodes.
    node_collections["intermediate solution nodes"] = self.draw_nodes(
        nodelist=self.intermediate_solution_nodes,
        node_size=node_size,
        node_color=colors_for(self.intermediate_solution_nodes),
        linewidths=solution_node_linewidth,
        edgecolors=node_border_color,
        zorder=80)
    # Plot solution edges.
    node_collections["solution edges"] = self.draw_edges(
        edgelist=self.solution_edges, width=solution_edge_width, zorder=70)
    # Plot non-solution nodes.
    node_collections["non-solution nodes"] = self.draw_nodes(
        nodelist=self.non_solution_nodes,
        node_size=node_size,
        node_color=colors_for(self.non_solution_nodes),
        linewidths=node_linewidth,
        edgecolors=node_border_color,
        zorder=20)
    # Plot non-solution edges.
    node_collections["non-solution edges"] = self.draw_edges(
        edgelist=self.non_solution_edges, width=edge_width, zorder=10)
    # Set title as solution length.
    self._ax.set_title("Solution length: {}".format(self.solution_length))
    return node_collections
def get_edges_from_list(list):  # pylint: disable=redefined-builtin
  """Wraps every element of `list` in its own single-element list.

  Args:
    list: An iterable of values. The name shadows the builtin; it is kept
      unchanged so that callers passing it by keyword keep working.

  Returns:
    A new `list` of the form `[[e0], [e1], ...]`, in the input order.
  """
  return [[e] for e in list]