From 1aecf4d23bb35a5505da8b6b3149287253b79a32 Mon Sep 17 00:00:00 2001 From: Terry Stewart Date: Wed, 25 Jan 2017 21:18:20 -0500 Subject: [PATCH] Generate seeds before adding graph components This means that if you do set the seed for the top-level Network in your model, adding or removing graphs will not affect the seed for the Ensembles in your model. Fixes #855 --- nengo_gui/page.py | 11 +++++++++++ nengo_gui/seed_generation.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 nengo_gui/seed_generation.py diff --git a/nengo_gui/page.py b/nengo_gui/page.py index 46372269..1030b090 100644 --- a/nengo_gui/page.py +++ b/nengo_gui/page.py @@ -12,6 +12,7 @@ import nengo_gui import nengo_gui.user_action import nengo_gui.config +import nengo_gui.seed_generation class PageSettings(object): @@ -434,6 +435,12 @@ def build(self): with self.lock: self.building = True + # set all the seeds so that creating components doesn't affect + # the neural model itself + seeds = nengo_gui.seed_generation.define_all_seeds(self.model) + for obj, s in seeds.items(): + obj.seed = s + # modify the model for the various Components for c in self.components: c.add_nengo_objects(self) @@ -456,6 +463,10 @@ def build(self): line = nengo_gui.exec_env.determine_line_number() self.error = dict(trace=traceback.format_exc(), line=line) + # set the defined seeds back to None + for obj in seeds: + obj.seed = None + self.stdout += exec_env.stdout.getvalue() if self.sim is not None: diff --git a/nengo_gui/seed_generation.py b/nengo_gui/seed_generation.py new file mode 100644 index 00000000..d6d597dd --- /dev/null +++ b/nengo_gui/seed_generation.py @@ -0,0 +1,32 @@ +import nengo +import nengo.utils.numpy as npext +import numpy as np + + +def define_all_seeds(net, seeds=None): + if seeds is None: + seeds = {} + + if net.seed is None: + if net not in seeds: + # this only happens at the very top level + seeds[net] = np.random.randint(npext.maxint) + rng = np.random.RandomState(seed=seeds[net]) + else: + rng = np.random.RandomState(seed=net.seed) + + # let's use the same algorithm as the builder, just to be consistent + sorted_types = sorted(net.objects, key=lambda t: t.__name__) + for obj_type in sorted_types: + for obj in net.objects[obj_type]: + # generate a seed for each item, so that manually setting a seed + # for a particular item doesn't change the generated seed for + # other items + generated_seed = rng.randint(npext.maxint) + if obj.seed is None: + seeds[obj] = generated_seed + + for subnet in net.networks: + define_all_seeds(subnet, seeds) + + return seeds