Skip to content

Commit

Permalink
initial state setting
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfaff committed Oct 28, 2024
1 parent 2a0b394 commit c5c5842
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
9 changes: 7 additions & 2 deletions src/emcee/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(
# Get the last random state
state = self.backend.random_state
if state is not None:
self._random.bit_generator.state = state
self.random_state = state

# Grab the last step so that we can restart
it = self.backend.iteration
Expand Down Expand Up @@ -225,7 +225,12 @@ def random_state(self):
def rng_dict(rng):
bg_state = rng.bit_generator.state
ss = rng.bit_generator.seed_seq
ss_dict = dict(entropy=ss.entropy, spawn_key=ss.spawn_key, pool_size=ss.pool_size, n_children_spawned=ss.n_children_spawned)
ss_dict = dict(
entropy=ss.entropy,
spawn_key=ss.spawn_key,
pool_size=ss.pool_size,
n_children_spawned=ss.n_children_spawned
)
return dict(bg_state=bg_state, seed_seq=ss_dict)
return rng_dict(self._random)
# return self._random.bit_generator.state
Expand Down
4 changes: 2 additions & 2 deletions src/emcee/tests/unit/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_backend(backend, dtype, blobs):
last2 = sampler2.get_last_sample()
assert np.allclose(last1.coords, last2.coords)
assert np.allclose(last1.log_prob, last2.log_prob)
assert last1.random_state == last2.random_state
assert last1.random_state['bg_state'] == last2.random_state['bg_state']
if blobs:
_custom_allclose(last1.blobs, last2.blobs)
else:
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_restart(backend, dtype):
last2 = sampler2.get_last_sample()
assert np.allclose(last1.coords, last2.coords)
assert np.allclose(last1.log_prob, last2.log_prob)
assert last1.random_state == last2.random_state
assert last1.random_state['bg_state'] == last2.random_state['bg_state']
_custom_allclose(last1.blobs, last2.blobs)

a = sampler1.acceptance_fraction
Expand Down

0 comments on commit c5c5842

Please sign in to comment.