diff --git a/quax/examples/prng/_core.py b/quax/examples/prng/_core.py index 046ba75..683ae59 100644 --- a/quax/examples/prng/_core.py +++ b/quax/examples/prng/_core.py @@ -1,7 +1,7 @@ import abc import functools as ft from collections.abc import Sequence -from typing import Any, TypeVar +from typing import Any, TypeVar, Union from typing_extensions import Self, TYPE_CHECKING, TypeAlias import equinox as eqx @@ -52,8 +52,12 @@ class ThreeFry(PRNG): value: UInt32[Array, "*batch 2"] - def __init__(self, seed: Integer[ArrayLike, ""]): - self.value = jax._src.prng.threefry_seed(jnp.asarray(seed)) + def __init__(self, seed: Union[Integer[ArrayLike, ""], "ThreeFry"]): + self.value = ( + jax._src.prng.threefry_seed(jnp.asarray(seed)) + if not isinstance(seed, ThreeFry) + else seed.value + ) def aval(self): *shape, _ = self.value.shape diff --git a/tests/test_prng.py b/tests/test_prng.py index fa268f5..e637639 100644 --- a/tests/test_prng.py +++ b/tests/test_prng.py @@ -61,3 +61,9 @@ def body(carry, _): return cumvals run(prng.ThreeFry(0)) + + +def test_threefry_from_threefry(): + key = prng.ThreeFry(0) + new_key = prng.ThreeFry(key) + assert key == new_key