ENH: use Generator instead of RandomState
andyfaff committed Oct 28, 2024
1 parent 4773226 commit e2536a9
Showing 11 changed files with 67 additions and 62 deletions.
25 changes: 17 additions & 8 deletions src/emcee/backends/hdf.py
@@ -4,6 +4,7 @@

import os
from tempfile import NamedTemporaryFile
import json

import numpy as np

@@ -19,6 +20,13 @@
h5py = None


class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
return super().default(obj)


def does_hdf5_support_longdouble():
if h5py is None:
return False
@@ -193,12 +201,11 @@ def accepted(self):
@property
def random_state(self):
with self.open() as f:
elements = [
v
for k, v in sorted(f[self.name].attrs.items())
if k.startswith("random_state_")
]
return elements if len(elements) else None
try:
dct = json.loads(f[self.name].attrs['random_state'])
except KeyError:
return None
return dct

def grow(self, ngrow, blobs):
"""Expand the storage space by some number of samples
Expand Down Expand Up @@ -261,8 +268,10 @@ def save_step(self, state, accepted):
g["blobs"][iteration, :] = state.blobs
g["accepted"][:] += accepted

for i, v in enumerate(state.random_state):
g.attrs["random_state_{0}".format(i)] = v
g.attrs["random_state"] = json.dumps(
state.random_state,
cls=NumpyEncoder
)

g.attrs["iteration"] = iteration + 1

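The backend now stores the generator state as a single JSON attribute instead of one HDF5 attribute per element of the old RandomState tuple. A minimal sketch of that round trip, independent of the emcee backend (the encoder mirrors NumpyEncoder above; PCG64 state is plain ints, but the encoder also covers bit generators such as MT19937 whose state contains a numpy array):

import json

import numpy as np


class NumpyEncoder(json.JSONEncoder):
    # json.dumps cannot serialize numpy arrays, so convert them to lists.
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


rng = np.random.default_rng(42)
payload = json.dumps(rng.bit_generator.state, cls=NumpyEncoder)

# Restoring the state on a fresh generator reproduces the stream.
rng2 = np.random.default_rng()
rng2.bit_generator.state = json.loads(payload)
assert rng.integers(100) == rng2.integers(100)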
39 changes: 25 additions & 14 deletions src/emcee/ensemble.py
@@ -89,6 +89,7 @@ def __init__(
vectorize=False,
blobs_dtype=None,
parameter_names: Optional[Union[Dict[str, int], List[str]]] = None,
rng = None,
# Deprecated...
a=None,
postargs=None,
@@ -136,11 +137,14 @@ def __init__(
self.nwalkers = nwalkers
self.backend = Backend() if backend is None else backend

# This is a random number generator that we can easily set the state
# of
self._random = np.random.default_rng(rng)

# Deal with re-used backends
if not self.backend.initialized:
self._previous_state = None
self.reset()
state = np.random.get_state()
else:
# Check the backend shape
if self.backend.shape != (self.nwalkers, self.ndim):
@@ -153,19 +157,14 @@

# Get the last random state
state = self.backend.random_state
if state is None:
state = np.random.get_state()
if state is not None:
self._random.bit_generator.state = state

# Grab the last step so that we can restart
it = self.backend.iteration
if it > 0:
self._previous_state = self.get_last_sample()

# This is a random number generator that we can easily set the state
# of without affecting the numpy-wide generator
self._random = np.random.mtrand.RandomState()
self._random.set_state(state)

# Do a little bit of _magic_ to make the likelihood call with
# ``args`` and ``kwargs`` pickleable.
self.log_prob_fn = _FunctionWrapper(log_prob_fn, args, kwargs)
@@ -216,14 +215,18 @@ def __init__(
@property
def random_state(self):
"""
The state of the internal random number generator. In practice, it's
the result of calling ``get_state()`` on a
``numpy.random.mtrand.RandomState`` object. You can try to set this
The state of the internal random number generator. You can try to set this
property but be warned that if you do this and it fails, it will do
so silently.
"""
return self._random.get_state()
def rng_dict(rng):
bg_state = rng.bit_generator.state
ss = rng.bit_generator.seed_seq
ss_dict = dict(entropy=ss.entropy, spawn_key=ss.spawn_key, pool_size=ss.pool_size, n_children_spawned=ss.n_children_spawned)
return dict(bg_state=bg_state, seed_seq=ss_dict)
return rng_dict(self._random)
# return self._random.bit_generator.state

@random_state.setter # NOQA
def random_state(self, state):
@@ -232,8 +235,16 @@ def random_state(self, state):
if it doesn't work. Don't say I didn't warn you...
"""
def _rng_fromdict(d):
bg_state = d['bg_state']
ss = np.random.SeedSequence(**d['seed_seq'])
bg = getattr(np.random, bg_state['bit_generator'])(ss)
bg.state = bg_state
rng = np.random.Generator(bg)
return rng
try:
self._random.set_state(state)
self._random = _rng_fromdict(state)
# self._random.bit_generator = state
except:
pass

@@ -325,7 +336,7 @@ def sample(
# Try to set the initial value of the random number generator. This
# fails silently if it doesn't work but that's what we want because
# we'll just interpret any garbage as letting the generator stay in
# it's current state.
# its current state.
if rstate0 is not None:
deprecation_warning(
"The 'rstate0' argument is deprecated, use a 'State' "
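The random_state property and its setter round-trip a Generator through a plain dict: the bit-generator state plus the SeedSequence metadata. A standalone sketch of the same pattern (bit_generator.seed_seq is assumed to be available, as in the code above; older NumPy releases expose it as _seed_seq):

import numpy as np


def rng_to_dict(rng):
    # Capture the bit-generator state and the seed-sequence metadata.
    ss = rng.bit_generator.seed_seq
    ss_dict = dict(
        entropy=ss.entropy,
        spawn_key=ss.spawn_key,
        pool_size=ss.pool_size,
        n_children_spawned=ss.n_children_spawned,
    )
    return dict(bg_state=rng.bit_generator.state, seed_seq=ss_dict)


def rng_from_dict(d):
    # Rebuild the seed sequence, then the named bit generator, then the Generator.
    ss = np.random.SeedSequence(**d["seed_seq"])
    bg = getattr(np.random, d["bg_state"]["bit_generator"])(ss)
    bg.state = d["bg_state"]
    return np.random.Generator(bg)


rng = np.random.default_rng(123)
rng.standard_normal(5)  # advance the stream a little
clone = rng_from_dict(rng_to_dict(rng))
assert clone.standard_normal() == rng.standard_normal()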
2 changes: 1 addition & 1 deletion src/emcee/moves/de.py
@@ -53,7 +53,7 @@ def get_proposal(self, s, c, random):
diffs = np.diff(c[pairs], axis=1).squeeze(axis=1) # (ns, ndim)

# Sample a gamma value for each walker following Nelson et al. (2013)
gamma = self.g0 * (1 + self.sigma * random.randn(ns, 1)) # (ns, 1)
gamma = self.g0 * (1 + self.sigma * random.standard_normal((ns, 1))) # (ns, 1)

# In this way, sigma is the standard deviation of the distribution of gamma,
# instead of the standard deviation of the distribution of the proposal as proposed by Ter Braak (2006).
2 changes: 1 addition & 1 deletion src/emcee/moves/de_snooker.py
@@ -35,7 +35,7 @@ def get_proposal(self, s, c, random):
q = np.empty_like(s)
metropolis = np.empty(Ns, dtype=np.float64)
for i in range(Ns):
w = np.array([c[j][random.randint(Nc[j])] for j in range(3)])
w = np.array([c[j][random.integers(Nc[j])] for j in range(3)])
random.shuffle(w)
z, z1, z2 = w
delta = s[i] - z
6 changes: 3 additions & 3 deletions src/emcee/moves/gaussian.py
@@ -87,13 +87,13 @@ def get_factor(self, rng):
return np.exp(rng.uniform(-self._log_factor, self._log_factor))

def get_updated_vector(self, rng, x0):
return x0 + self.get_factor(rng) * self.scale * rng.randn(*(x0.shape))
return x0 + self.get_factor(rng) * self.scale * rng.standard_normal((x0.shape))

def __call__(self, x0, rng):
nw, nd = x0.shape
xnew = self.get_updated_vector(rng, x0)
if self.mode == "random":
m = (range(nw), rng.randint(x0.shape[-1], size=nw))
m = (range(nw), rng.integers(x0.shape[-1], size=nw))
elif self.mode == "sequential":
m = (range(nw), self.index % nd + np.zeros(nw, dtype=int))
self.index = (self.index + 1) % nd
@@ -106,7 +106,7 @@ def __call__(self, x0, rng):

class _diagonal_proposal(_isotropic_proposal):
def get_updated_vector(self, rng, x0):
return x0 + self.get_factor(rng) * self.scale * rng.randn(*(x0.shape))
return x0 + self.get_factor(rng) * self.scale * rng.standard_normal((x0.shape))


class _proposal(_isotropic_proposal):
2 changes: 1 addition & 1 deletion src/emcee/moves/mh.py
@@ -56,7 +56,7 @@ def propose(self, model, state):

# Loop over the walkers and update them accordingly.
lnpdiff = new_log_probs - state.log_prob + factors
accepted = np.log(model.random.rand(nwalkers)) < lnpdiff
accepted = np.log(model.random.random(nwalkers)) < lnpdiff

# Update the parameters
new_state = State(q, log_prob=new_log_probs, blobs=new_blobs)
2 changes: 1 addition & 1 deletion src/emcee/moves/red_blue.py
@@ -97,7 +97,7 @@ def propose(self, model, state):
zip(all_inds[S1], factors, new_log_probs)
):
lnpdiff = f + nlp - state.log_prob[j]
if lnpdiff > np.log(model.random.rand()):
if lnpdiff > np.log(model.random.random()):
accepted[j] = True

new_state = State(q, log_prob=new_log_probs, blobs=new_blobs)
4 changes: 2 additions & 2 deletions src/emcee/moves/stretch.py
@@ -27,7 +27,7 @@ def get_proposal(self, s, c, random):
c = np.concatenate(c, axis=0)
Ns, Nc = len(s), len(c)
ndim = s.shape[1]
zz = ((self.a - 1.0) * random.rand(Ns) + 1) ** 2.0 / self.a
zz = ((self.a - 1.0) * random.random(Ns) + 1) ** 2.0 / self.a
factors = (ndim - 1.0) * np.log(zz)
rint = random.randint(Nc, size=(Ns,))
rint = random.integers(Nc, size=(Ns,))
return c[rint] - (c[rint] - s) * zz[:, None], factors
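Across the move implementations, the legacy RandomState draw methods are replaced by their Generator equivalents. A quick reference of the renames used in this commit:

import numpy as np

rng = np.random.default_rng(0)

u = rng.random(5)                  # was random.rand(5)
z = rng.standard_normal((3, 1))    # was random.randn(3, 1)
idx = rng.integers(10, size=(4,))  # was random.randint(10, size=(4,))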
35 changes: 10 additions & 25 deletions src/emcee/tests/unit/test_backends.py
@@ -43,11 +43,10 @@ def run_sampler(
):
if lp is None:
lp = normal_log_prob_blobs if blobs else normal_log_prob
if seed is not None:
np.random.seed(seed)
coords = np.random.randn(nwalkers, ndim)
rng = np.random.default_rng(seed)
coords = rng.standard_normal((nwalkers, ndim))
sampler = EnsembleSampler(
nwalkers, ndim, lp, backend=backend, blobs_dtype=dtype
nwalkers, ndim, lp, rng=rng, backend=backend, blobs_dtype=dtype
)
sampler.run_mcmc(coords, nsteps, thin_by=thin_by)
return sampler
@@ -125,10 +124,7 @@ def test_backend(backend, dtype, blobs):
last2 = sampler2.get_last_sample()
assert np.allclose(last1.coords, last2.coords)
assert np.allclose(last1.log_prob, last2.log_prob)
assert all(
np.allclose(l1, l2)
for l1, l2 in zip(last1.random_state[1:], last2.random_state[1:])
)
assert last1.random_state == last2.random_state
if blobs:
_custom_allclose(last1.blobs, last2.blobs)
else:
@@ -141,12 +137,11 @@ def test_reload(backend, dtype):

@pytest.mark.parametrize("backend,dtype", product(other_backends, dtypes))
def test_reload(backend, dtype):
with backend() as backend1:
with (backend() as backend1):
run_sampler(backend1, dtype=dtype)

# Test the state
state = backend1.random_state
np.random.set_state(state)

# Load the file using a new backend object.
backend2 = backends.HDFBackend(
@@ -156,11 +151,7 @@ def test_reload(backend, dtype):
with pytest.raises(RuntimeError):
backend2.reset(32, 3)

assert state[0] == backend2.random_state[0]
assert all(
np.allclose(a, b)
for a, b in zip(state[1:], backend2.random_state[1:])
)
assert state == backend2.random_state

# Check all of the components.
for k in ["chain", "log_prob", "blobs"]:
@@ -172,10 +163,7 @@ def test_reload(backend, dtype):
last2 = backend2.get_last_sample()
assert np.allclose(last1.coords, last2.coords)
assert np.allclose(last1.log_prob, last2.log_prob)
assert all(
np.allclose(l1, l2)
for l1, l2 in zip(last1.random_state[1:], last2.random_state[1:])
)
assert last1.random_state == last2.random_state
_custom_allclose(last1.blobs, last2.blobs)

a = backend1.accepted
@@ -188,11 +176,11 @@ def test_restart(backend, dtype):
# Run a sampler with the default backend.
b = backends.Backend()
run_sampler(b, dtype=dtype)
sampler1 = run_sampler(b, seed=None, dtype=dtype)
sampler1 = run_sampler(b, seed=2, dtype=dtype)

with backend() as be:
run_sampler(be, dtype=dtype)
sampler2 = run_sampler(be, seed=None, dtype=dtype)
sampler2 = run_sampler(be, seed=2, dtype=dtype)

# Check all of the components.
for k in ["chain", "log_prob", "blobs"]:
@@ -204,10 +192,7 @@ def test_restart(backend, dtype):
last2 = sampler2.get_last_sample()
assert np.allclose(last1.coords, last2.coords)
assert np.allclose(last1.log_prob, last2.log_prob)
assert all(
np.allclose(l1, l2)
for l1, l2 in zip(last1.random_state[1:], last2.random_state[1:])
)
assert last1.random_state == last2.random_state
_custom_allclose(last1.blobs, last2.blobs)

a = sampler1.acceptance_fraction
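The test helpers now construct a Generator up front and hand it to the sampler through the new rng keyword; because the sampler passes it through np.random.default_rng, the argument may be None, an integer seed, or an existing Generator. A minimal sketch of the pattern (log_prob is a stand-in for any log-probability function):

import numpy as np

from emcee import EnsembleSampler


def log_prob(x):
    # Standard-normal log-density, standing in for a real model.
    return -0.5 * np.sum(x**2)


nwalkers, ndim = 32, 3
rng = np.random.default_rng(1234)
coords = rng.standard_normal((nwalkers, ndim))

sampler = EnsembleSampler(nwalkers, ndim, log_prob, rng=rng)
sampler.run_mcmc(coords, 100)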
6 changes: 3 additions & 3 deletions src/emcee/tests/unit/test_sampler.py
@@ -135,9 +135,9 @@ def run_sampler(
progress=False,
store=True,
):
np.random.seed(seed)
coords = np.random.randn(nwalkers, ndim)
sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob, backend=backend)
rng = np.random.default_rng(seed)
coords = rng.standard_normal((nwalkers, ndim))
sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob, rng=rng, backend=backend)
sampler.run_mcmc(
coords,
nsteps,
6 changes: 3 additions & 3 deletions src/emcee/tests/unit/test_stretch.py
@@ -16,12 +16,12 @@ def test_live_dangerously(nwalkers=32, nsteps=3000, seed=1234):
warnings.filterwarnings("error")

# Set up the random number generator.
np.random.seed(seed)
rng = np.random.default_rng(seed)
state = State(
np.random.randn(nwalkers, 2 * nwalkers),
rng.standard_normal((nwalkers, 2 * nwalkers)),
log_prob=np.random.randn(nwalkers),
)
model = Model(None, lambda x: (np.zeros(len(x)), None), map, np.random)
model = Model(None, lambda x: (np.zeros(len(x)), None), map, rng)
proposal = moves.StretchMove()

# Test to make sure that the error is thrown if there aren't enough
