Skip to content

Commit

Permalink
feat: convenience construction of a threefry
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Mar 7, 2024
1 parent f5736cf commit 89a85bc
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
10 changes: 7 additions & 3 deletions quax/examples/prng/_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import functools as ft
from collections.abc import Sequence
from typing import Any, TypeVar
from typing import Any, TypeVar, Union
from typing_extensions import Self, TYPE_CHECKING, TypeAlias

import equinox as eqx
Expand Down Expand Up @@ -52,8 +52,12 @@ class ThreeFry(PRNG):

value: UInt32[Array, "*batch 2"]

def __init__(self, seed: Integer[ArrayLike, ""]):
self.value = jax._src.prng.threefry_seed(jnp.asarray(seed))
def __init__(self, seed: Union[Integer[ArrayLike, ""], "ThreeFry"]):
self.value = (
jax._src.prng.threefry_seed(jnp.asarray(seed))
if not isinstance(seed, ThreeFry)
else seed.value
)

def aval(self):
*shape, _ = self.value.shape
Expand Down
6 changes: 6 additions & 0 deletions tests/test_prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,9 @@ def body(carry, _):
return cumvals

run(prng.ThreeFry(0))


def test_threefry_from_threefry():
key = prng.ThreeFry(0)
new_key = prng.ThreeFry(key)
assert key == new_key

0 comments on commit 89a85bc

Please sign in to comment.