From 280645f360a2b9764f2afa7854d9017dd850d917 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Fri, 1 Dec 2023 14:09:39 +0100 Subject: [PATCH] 'jax.tree_util' -> 'jtu' --- src/dilax/model.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/src/dilax/model.py b/src/dilax/model.py index 40a214e..3dbdf11 100644 --- a/src/dilax/model.py +++ b/src/dilax/model.py @@ -6,6 +6,7 @@ import equinox as eqx import jax import jax.numpy as jnp +import jax.tree_util as jtu from dilax.parameter import Parameter from dilax.util import Sentinel, _NoValue, deep_update @@ -31,7 +32,7 @@ def add(self, process: str, expectation: jax.Array) -> Result: return self def expectation(self) -> jax.Array: - return cast(jax.Array, sum(jax.tree_util.tree_leaves(self.expectations))) + return cast(jax.Array, sum(jtu.tree_leaves(self.expectations))) def _is_parameter(leaf: Any) -> bool: @@ -120,7 +121,7 @@ def __init__( @property def parameter_values(self) -> dict: - return jax.tree_util.tree_map( + return jtu.tree_map( lambda l: l.value, # noqa: E741 self.parameters, is_leaf=_is_parameter, @@ -135,7 +136,7 @@ def _constraint(param: Parameter) -> jax.Array: return next(iter(param.constraints)).logpdf(param.value) return jnp.array([0.0]) - return jax.tree_util.tree_map( + return jtu.tree_map( _constraint, self.parameters, is_leaf=_is_parameter, @@ -160,9 +161,7 @@ def update( # patch original parameters with new ones _updates = deep_update( - jax.tree_util.tree_map( - lambda _: None, self.parameters, is_leaf=_is_parameter - ), + jtu.tree_map(lambda _: None, self.parameters, is_leaf=_is_parameter), values, ) @@ -171,7 +170,7 @@ def _update_params(update: jax.Array | None, param: Parameter) -> Parameter: return param return param.update(value=update) - new_parameters = jax.tree_util.tree_map( + new_parameters = jtu.tree_map( _update_params, _updates, self.parameters, @@ -183,12 +182,17 @@ def _update_params(update: jax.Array | None, param: Parameter) -> Parameter: ) def nll_boundary_penalty(self) -> jax.Array: - params = jax.tree_util.tree_leaves(self.parameters, is_leaf=_is_parameter) - - return sum( - jax.tree_util.tree_map( - lambda p: p.boundary_penalty, params, is_leaf=_is_parameter - ) + return cast( + jax.Array, + sum( + jtu.tree_leaves( + jtu.tree_map( + lambda p: p.boundary_penalty, + self.parameters, + is_leaf=_is_parameter, + ) + ) + ), ) @abc.abstractmethod