
Commit

Optimise activity grad for jit
francesco-innocenti committed Jun 24, 2024
1 parent 65b652e commit 1d35387
Showing 1 changed file with 9 additions and 7 deletions.
jpc/_core/_grads.py: 16 changes (9 additions & 7 deletions)
@@ -1,14 +1,16 @@
"""Functions to compute gradients of the free energy."""

from jax import grad
from jax.tree_util import tree_map
from equinox import filter_grad
from jaxtyping import PyTree, ArrayLike, Array
from typing import Union, Tuple, Callable, Optional
from typing import Tuple, Callable, Optional
from ._energies import pc_energy_fn


def _neg_activity_grad(
t: Union[float, int],
energy_fn: Callable,
t: float | int,
activities: PyTree[ArrayLike],
args: Tuple[Optional[PyTree[Callable]], ArrayLike, ArrayLike]
) -> PyTree[Array]:
@@ -22,14 +24,14 @@ def _neg_activity_grad(
     **Main arguments:**
+    - `pc_energy_fn`: Free energy to take the gradient of.
     - `t`: Time step of the ODE system, used for downstream integration by
         `diffrax.diffeqsolve`.
     - `activities`: List of activities for each layer free to vary.
-    - `args`: 4-Tuple with
+    - `args`: 3-Tuple with
         (i) list of callable layers of the generative model,
-        (ii) optional list of callable layers of a network to amortise inference,
-        (iii) network output (observation), and
-        (iv) network input (prior).
+        (ii) network output (observation), and
+        (iii) network input (prior).
     **Returns:**
@@ -43,7 +45,7 @@ def _neg_activity_grad(
         output,
         input
     )
-    return [-dFdz for dFdz in dFdzs]
+    return tree_map(lambda dFdz: -dFdz, dFdzs)
 
 
 def compute_pc_param_grads(
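The change in the final hunk swaps a Python list comprehension for `jax.tree_util.tree_map`, so the negation applies to every leaf of an arbitrary PyTree of activity gradients (not just a flat list) and traces cleanly under `jax.jit`. Below is a minimal sketch of the pattern, using a toy quadratic energy and hypothetical names rather than jpc's `pc_energy_fn`:

    # Illustrative sketch only; `toy_energy` and `neg_energy_grad` are not part of jpc.
    from jax import grad, jit
    import jax.numpy as jnp
    from jax.tree_util import tree_map

    def toy_energy(activities):
        # Scalar "free energy": sum of squared activities across layers.
        return sum(jnp.sum(z ** 2) for z in activities)

    @jit
    def neg_energy_grad(activities):
        dFdzs = grad(toy_energy)(activities)
        # tree_map negates every leaf of the gradient PyTree, whereas a list
        # comprehension would assume dFdzs is a flat Python list.
        return tree_map(lambda dFdz: -dFdz, dFdzs)

    activities = [jnp.ones((3,)), jnp.ones((2,))]
    print(neg_energy_grad(activities))  # each layer's gradient 2*z, negated to -2*z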

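The docstring notes that `_neg_activity_grad` is used as the vector field integrated by `diffrax.diffeqsolve`. The sketch below shows that style of usage under stated assumptions: it keeps the standard diffrax `(t, y, args)` vector-field signature, uses a toy energy, and picks an arbitrary solver and step size; none of this reproduces jpc's actual defaults, and the `energy_fn` parameter introduced in this commit would presumably be bound (e.g. via `functools.partial`) before wrapping the function in an `ODETerm`.

    # Illustrative sketch only; energy, solver, shapes and names are assumptions.
    import jax.numpy as jnp
    from jax import grad
    from jax.tree_util import tree_map
    import diffrax

    def toy_energy(activities, target):
        return sum(jnp.sum((z - target) ** 2) for z in activities)

    def neg_activity_grad(t, activities, args):
        # diffrax expects vector fields with the signature (t, y, args).
        target = args
        dFdzs = grad(toy_energy)(activities, target)
        return tree_map(lambda dFdz: -dFdz, dFdzs)

    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(neg_activity_grad),
        solver=diffrax.Heun(),
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=[jnp.zeros(3), jnp.zeros(3)],  # initial activities, one array per layer
        args=jnp.ones(3),                 # "observation" the activities relax towards
    )
    # By default diffrax saves only the state at t1 (with a leading save dimension of 1).
    equilibrated_activities = solution.ys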