From eb76127942b672f2b130ca00de1364dad82a4ea7 Mon Sep 17 00:00:00 2001 From: Francesco Innocenti Date: Mon, 24 Jun 2024 13:30:02 +0100 Subject: [PATCH] Fix bug in activity grad --- jpc/_core/_grads.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jpc/_core/_grads.py b/jpc/_core/_grads.py index 7c77d1d..3ac31f9 100644 --- a/jpc/_core/_grads.py +++ b/jpc/_core/_grads.py @@ -9,10 +9,10 @@ def _neg_activity_grad( - energy_fn: Callable, t: float | int, activities: PyTree[ArrayLike], - args: Tuple[Optional[PyTree[Callable]], ArrayLike, ArrayLike] + args: Tuple[Optional[PyTree[Callable]], ArrayLike, ArrayLike], + energy_fn: Callable = pc_energy_fn, ) -> PyTree[Array]: """Computes the negative gradient of the energy with respect to the activities. @@ -24,7 +24,6 @@ def _neg_activity_grad( **Main arguments:** - - `pc_energy_fn`: Free energy to take the gradient of. - `t`: Time step of the ODE system, used for downstream integration by `diffrax.diffeqsolve`. - `activities`: List of activities for each layer free to vary. @@ -32,6 +31,7 @@ def _neg_activity_grad( (i) list of callable layers of the generative model, (ii) network output (observation), and (iii) network input (prior). + - `pc_energy_fn`: Free energy to take the gradient of. **Returns:**