From 4b2e64e34e672212e4fe947674c508a410de9ef0 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 16 Nov 2024 12:58:26 +0530 Subject: [PATCH] bringing back PRNGKey instead of key, till the python311 branch is merged --- algorithmic_efficiency/random_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 93dc263bd..bcfc59c92 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType: def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.key(seed) + return jax_rng.PRNGKey(seed) return _PRNGKey(seed)