Skip to content

Commit

Permalink
test_evals_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
harneet862 committed Nov 17, 2024
1 parent 0e646be commit 34d917b
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions algorithmic_efficiency/random_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,30 @@

# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
# unsigned int), while RandomState.randint only accepts and returns signed ints.
MAX_INT32 = 2**31
MIN_INT32 = -MAX_INT32
MAX_UINT32 = 2**32-1
MIN_UINT32 = 0

SeedType = Union[int, list, np.ndarray]


def _signed_to_unsigned(seed: SeedType) -> SeedType:
if isinstance(seed, int):
return seed % 2**32
return seed % MAX_UINT32
if isinstance(seed, list):
return [s % 2**32 for s in seed]
return [s % MAX_UINT32 for s in seed]
if isinstance(seed, np.ndarray):
return np.array([s % 2**32 for s in seed.tolist()])
return np.array([s % MAX_UINT32 for s in seed.tolist()])


def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
return [new_seed, data]


def _split(seed: SeedType, num: int = 2) -> SeedType:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])


def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
Expand Down Expand Up @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType:
def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
if FLAGS.framework == 'jax':
_check_jax_install()
return jax_rng.PRNGKey(seed)
return jax_rng.key(seed)
return _PRNGKey(seed)

0 comments on commit 34d917b

Please sign in to comment.