Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Change deprecated
jax.tree_map
to avoid warnings:
``` DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version). ``` PiperOrigin-RevId: 622276898
- Loading branch information