Skip to content

Commit

Permalink
Generate seeds before adding graph components
Browse files Browse the repository at this point in the history
This means that if you do set the seed for the top-level
Network in your model, adding or removing graphs will not
affect the seed for the Ensembles in your model.

Fixes #855
  • Loading branch information
tcstewar committed Feb 17, 2017
1 parent 55e730b commit 1aecf4d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
11 changes: 11 additions & 0 deletions nengo_gui/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import nengo_gui
import nengo_gui.user_action
import nengo_gui.config
import nengo_gui.seed_generation


class PageSettings(object):
Expand Down Expand Up @@ -434,6 +435,12 @@ def build(self):
with self.lock:
self.building = True

# set all the seeds so that creating components doesn't affect
# the neural model itself
seeds = nengo_gui.seed_generation.define_all_seeds(self.model)
for obj, s in seeds.items():
obj.seed = s

# modify the model for the various Components
for c in self.components:
c.add_nengo_objects(self)
Expand All @@ -456,6 +463,10 @@ def build(self):
line = nengo_gui.exec_env.determine_line_number()
self.error = dict(trace=traceback.format_exc(), line=line)

# set the defined seeds back to None
for obj in seeds:
obj.seed = None

self.stdout += exec_env.stdout.getvalue()

if self.sim is not None:
Expand Down
32 changes: 32 additions & 0 deletions nengo_gui/seed_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import nengo
import nengo.utils.numpy as npext
import numpy as np


def define_all_seeds(net, seeds=None):
if seeds is None:
seeds = {}

if net.seed is None:
if net not in seeds:
# this only happens at the very top level
seeds[net] = np.random.randint(npext.maxint)
rng = np.random.RandomState(seed=seeds[net])
else:
rng = np.random.RandomState(seed=net.seed)

# let's use the same algorithm as the builder, just to be consistent
sorted_types = sorted(net.objects, key=lambda t: t.__name__)
for obj_type in sorted_types:
for obj in net.objects[obj_type]:
# generate a seed for each item, so that manually setting a seed
# for a particular item doesn't change the generated seed for
# other items
generated_seed = rng.randint(npext.maxint)
if obj.seed is None:
seeds[obj] = generated_seed

for subnet in net.networks:
define_all_seeds(subnet, seeds)

return seeds

0 comments on commit 1aecf4d

Please sign in to comment.