diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 93dc263bd..bcfc59c92 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType: def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.key(seed) + return jax_rng.PRNGKey(seed) return _PRNGKey(seed)