Skip to content

Commit

Permalink
Fix bug in activity grad
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Jun 24, 2024
1 parent 79d24ea commit eb76127
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jpc/_core/_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@


def _neg_activity_grad(
energy_fn: Callable,
t: float | int,
activities: PyTree[ArrayLike],
args: Tuple[Optional[PyTree[Callable]], ArrayLike, ArrayLike]
args: Tuple[Optional[PyTree[Callable]], ArrayLike, ArrayLike],
energy_fn: Callable = pc_energy_fn,
) -> PyTree[Array]:
"""Computes the negative gradient of the energy with respect to the activities.
Expand All @@ -24,14 +24,14 @@ def _neg_activity_grad(
**Main arguments:**
- `pc_energy_fn`: Free energy to take the gradient of.
- `t`: Time step of the ODE system, used for downstream integration by
`diffrax.diffeqsolve`.
- `activities`: List of activities for each layer free to vary.
- `args`: 3-Tuple with
(i) list of callable layers of the generative model,
(ii) network output (observation), and
(iii) network input (prior).
- `pc_energy_fn`: Free energy to take the gradient of.
**Returns:**
Expand Down

0 comments on commit eb76127

Please sign in to comment.