Skip to content

Commit

Permalink
Set rng state for trace fn mapping draws to posterior samples
Browse files Browse the repository at this point in the history
Co-authored-by: Ricardo Vieira <[email protected]>
  • Loading branch information
lucianopaz and ricardoV94 committed Jan 10, 2025
1 parent 09def3b commit 312effd
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 12 deletions.
11 changes: 9 additions & 2 deletions pymc/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
from pymc.blocking import PointType
from pymc.model import Model
from pymc.step_methods.compound import BlockedStep, CompoundStep
from pymc.util import get_random_generator

HAS_MCB = False
try:
Expand Down Expand Up @@ -103,11 +104,13 @@ def _init_trace(
model: Model,
trace_vars: list[TensorVariable] | None = None,
initial_point: PointType | None = None,
rng: np.random.Generator | None = None,
) -> BaseTrace:
"""Initialize a trace backend for a chain."""
rng_ = get_random_generator(rng)
strace: BaseTrace
if trace is None:
strace = NDArray(model=model, vars=trace_vars, test_point=initial_point)
strace = NDArray(model=model, vars=trace_vars, test_point=initial_point, rng=rng_)
elif isinstance(trace, BaseTrace):
if len(trace) > 0:
raise ValueError("Continuation of traces is no longer supported.")
Expand All @@ -129,6 +132,7 @@ def init_traces(
model: Model,
trace_vars: list[TensorVariable] | None = None,
tune: int = 0,
rng: np.random.Generator | None = None,
) -> tuple[RunType | None, Sequence[IBaseTrace]]:
"""Initialize a trace recorder for each chain."""
if isinstance(backend, ZarrTrace):
Expand All @@ -140,6 +144,7 @@ def init_traces(
model=model,
vars=trace_vars,
test_point=initial_point,
rng=rng,
)
return None, backend.straces
if HAS_MCB and isinstance(backend, Backend):
Expand All @@ -149,6 +154,7 @@ def init_traces(
initial_point=initial_point,
step=step,
model=model,
rng=rng,
)

assert backend is None or isinstance(backend, BaseTrace)
Expand All @@ -161,7 +167,8 @@ def init_traces(
model=model,
trace_vars=trace_vars,
initial_point=initial_point,
rng=rng_,
)
for chain_number in range(chains)
for chain_number, rng_ in enumerate(get_random_generator(rng).spawn(chains))
]
return None, traces
5 changes: 4 additions & 1 deletion pymc/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from pymc.backends.report import SamplerReport
from pymc.model import modelcontext
from pymc.pytensorf import compile
from pymc.pytensorf import compile, set_function_rngs
from pymc.util import get_var_name

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -159,6 +159,7 @@ def __init__(
fn=None,
var_shapes=None,
var_dtypes=None,
rng=None,
):
model = modelcontext(model)

Expand All @@ -177,6 +178,8 @@ def __init__(
on_unused_input="ignore",
)
fn.trust_input = True
if rng is not None:
fn = set_function_rngs(fn=fn, rng=rng)

# Get variable shapes. Most backends will need this
# information.
Expand Down
16 changes: 13 additions & 3 deletions pymc/backends/mcbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from pymc.backends.base import IBaseTrace
from pymc.model import Model
from pymc.pytensorf import PointFunc
from pymc.pytensorf import PointFunc, set_function_rngs
from pymc.step_methods.compound import (
BlockedStep,
CompoundStep,
Expand All @@ -38,6 +38,7 @@
flat_statname,
flatten_steps,
)
from pymc.util import get_random_generator

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -96,7 +97,11 @@ class ChainRecordAdapter(IBaseTrace):
"""Wraps an McBackend ``Chain`` as an ``IBaseTrace``."""

def __init__(
self, chain: mcb.Chain, point_fn: PointFunc, stats_bijection: StatsBijection
self,
chain: mcb.Chain,
point_fn: PointFunc,
stats_bijection: StatsBijection,
rng: np.random.Generator | None = None,
) -> None:
# Assign attributes required by IBaseTrace
self.chain = chain.cmeta.chain_number
Expand All @@ -107,8 +112,11 @@ def __init__(
for sstats in stats_bijection._stat_groups
]

self._rng = rng
self._chain = chain
self._point_fn = point_fn
if rng is not None:
self._point_fn = set_function_rngs(self._point_fn, rng)
self._statsbj = stats_bijection
super().__init__()

Expand Down Expand Up @@ -257,6 +265,7 @@ def init_chain_adapters(
initial_point: Mapping[str, np.ndarray],
step: CompoundStep | BlockedStep,
model: Model,
rng: np.random.Generator | None,
) -> tuple[mcb.Run, list[ChainRecordAdapter]]:
"""Create an McBackend metadata description for the MCMC run.
Expand Down Expand Up @@ -286,7 +295,8 @@ def init_chain_adapters(
chain=run.init_chain(chain_number=chain_number),
point_fn=point_fn,
stats_bijection=statsbj,
rng=rng_,
)
for chain_number in range(chains)
for chain_number, rng_ in enumerate(get_random_generator(rng).spawn(chains))
]
return run, adapters
20 changes: 17 additions & 3 deletions pymc/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,20 @@
from pymc.backends.base import BaseTrace
from pymc.blocking import StatDtype, StatShape
from pymc.model.core import Model, modelcontext
from pymc.pytensorf import set_function_rngs
from pymc.step_methods.compound import (
BlockedStep,
CompoundStep,
StatsBijection,
get_stats_dtypes_shapes_from_steps,
)
from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name
from pymc.util import (
UNSET,
_UnsetType,
get_default_varnames,
get_random_generator,
is_transformed_name,
)

try:
from zarr.storage import BaseStore, default_compressor
Expand Down Expand Up @@ -398,6 +405,7 @@ def init_trace(
model: Model | None = None,
vars: Sequence[TensorVariable] | None = None,
test_point: dict[str, np.ndarray] | None = None,
rng: np.random.Generator | None = None,
):
"""Initialize the trace groups and arrays.
Expand Down Expand Up @@ -437,6 +445,12 @@ def init_trace(
This is not used and is a product of the inheritance of :class:`ZarrChain`
from :class:`~.BaseTrace`, which uses it to determine the shape and dtype
of `vars`.
rng : numpy.random.Generator | None
A random generator to use to seed the shared random generators that are
present in the pytensor function that maps samples drawn by step methods
onto samples in the posterior trace. Note that this only does anything
if there are deterministic variables that are generated by raw pytensor
random variables.
"""
if self._is_base_setup:
raise RuntimeError("The ZarrTrace has already been initialized") # pragma: no cover
Expand Down Expand Up @@ -534,9 +548,9 @@ def init_trace(
test_point=test_point,
stats_bijection=StatsBijection(step.stats_dtypes),
draws_per_chunk=self.draws_per_chunk,
fn=self.fn,
fn=set_function_rngs(self.fn, rng_),
)
for _ in range(chains)
for rng_ in get_random_generator(rng).spawn(chains)
]
for chain, strace in enumerate(self.straces):
strace.setup(draws=tune + draws, chain=chain, sampler_vars=None)
Expand Down
57 changes: 54 additions & 3 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
import warnings

from collections.abc import Callable, Generator, Iterable, Sequence
from typing import cast
from typing import cast, overload

import numpy as np
import pandas as pd
import pytensor
import pytensor.tensor as pt
import scipy.sparse as sps

from pytensor import shared
from pytensor.compile import Function, Mode, get_mode
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import grad
Expand All @@ -42,7 +43,7 @@
from pytensor.tensor.basic import _as_tensor_variable
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
from pytensor.tensor.rewriting.shape import ShapeFeature
Expand All @@ -51,7 +52,7 @@
from pytensor.tensor.variable import TensorVariable

from pymc.exceptions import NotConstantValueError
from pymc.util import makeiter
from pymc.util import RandomGeneratorState, makeiter, random_generator_from_state
from pymc.vartypes import continuous_types, isgenerator, typefilter

PotentialShapeType = int | np.ndarray | Sequence[int | Variable] | TensorVariable
Expand Down Expand Up @@ -1163,3 +1164,53 @@ def normalize_rng_param(rng: None | Variable) -> Variable:
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
)
return rng


@overload
def set_function_rngs(
    fn: PointFunc, rng: np.random.Generator | RandomGeneratorState
) -> PointFunc: ...


@overload
def set_function_rngs(
    fn: Function, rng: np.random.Generator | RandomGeneratorState
) -> Function: ...


def set_function_rngs(fn: Function, rng: np.random.Generator | RandomGeneratorState) -> Function:
    """Copy a compiled pytensor function and replace the random Generators with spawns.

    Parameters
    ----------
    fn : pytensor.compile.function.types.Function | pymc.pytensorf.PointFunc
        The compiled function
    rng : numpy.random.Generator | RandomGeneratorState
        The random generator or its state. It is only consulted when ``fn``
        actually holds shared random generators.

    Returns
    -------
    fn_out : pytensor.compile.function.types.Function | pymc.pytensorf.PointFunc
        A copy of the input function with the shared random generator states set to
        spawns of the supplied ``rng``. If the function has no shared random generators
        in it, the input ``fn`` is returned without any changes.
        If ``fn`` is a :class:`~pymc.pytensorf.PointFunc` instance, and the inner
        pytensor function has random variables, then the inner pytensor function is
        copied, setting new random generators, and a new ``PointFunc`` instance is
        returned.
    """
    # A PointFunc wraps the compiled function in its ``f`` attribute.
    inner_fn = fn.f if isinstance(fn, PointFunc) else fn
    shared_rngs = [
        var for var in inner_fn.get_shared() if isinstance(var.type, RandomGeneratorType)
    ]
    if not shared_rngs:
        # Nothing to reseed: skip building/spawning a generator and avoid a
        # needless copy of the compiled function.
        return fn

    rng_gen = rng if isinstance(rng, np.random.Generator) else random_generator_from_state(rng)
    # Copy the function and replace any shared RNGs.
    # This is needed so that it can work correctly with multiple traces.
    # Copying is costly, so this function should not be called too often!
    swap = {
        old_shared_rng: shared(child_rng, borrow=True)
        for old_shared_rng, child_rng in zip(
            shared_rngs, rng_gen.spawn(len(shared_rngs)), strict=True
        )
    }
    fn_copy = inner_fn.copy(swap=swap)
    return PointFunc(fn_copy) if isinstance(fn, PointFunc) else fn_copy
1 change: 1 addition & 0 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,7 @@ def joined_blas_limiter():
initial_point=initial_points[0],
model=model,
tune=tune,
rng=rngs[0].spawn(1)[0],
)

sample_args = {
Expand Down
29 changes: 29 additions & 0 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,3 +909,32 @@ def test_sample(self, seeded_test):
np.testing.assert_allclose(
x_pred, pp_trace1.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
)


@pytest.fixture(scope="function", params=[None, "mcbackend", "zarr"])
def trace_backend(request):
    """Provide the ``trace`` argument for each parametrized backend.

    Returns ``None``, an McBackend ``NumPyBackend`` or a ``ZarrTrace``
    depending on ``request.param``. The test is skipped when ``mcbackend``
    cannot be imported or when constructing the ``ZarrTrace`` raises
    ``RuntimeError``.
    """
    backend_name = request.param

    if backend_name == "mcbackend":
        try:
            import mcbackend as mcb
        except ImportError:
            pytest.skip("Requires McBackend to be installed.")
        return mcb.NumPyBackend()

    if backend_name == "zarr":
        try:
            return pm.backends.zarr.ZarrTrace()
        except RuntimeError:
            pytest.skip("Requires zarr to be installed")

    # ``None`` (and anything unrecognised) is passed through as no backend.
    return None


def test_random_deterministics(trace_backend):
    """Check that two runs with the same seed give equal posteriors.

    The deterministic ``y`` adds an unregistered ``pm.Normal.dist()`` draw to
    a term that is always zero, so its value depends only on the random
    generator used by the function that maps draws to trace samples.
    """
    sample_kwargs = {"tune": 0, "draws": 1, "random_seed": 1, "trace": trace_backend}

    with pm.Model():
        x = pm.Bernoulli("x", p=0.5) * 0  # Force it to be zero
        pm.Deterministic("y", x + pm.Normal.dist())

        first_run = pm.sample(**sample_kwargs)
        second_run = pm.sample(**sample_kwargs)

    assert first_run.posterior.equals(second_run.posterior)

0 comments on commit 312effd

Please sign in to comment.