Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Jun 28, 2024
1 parent 371adc9 commit f8ff3a3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jpc/_core/_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ def pc_energy_fn(
!!! note
The input and output correspond to the prior and observation of
The input x and output y correspond to the prior and observation of
the generative model, respectively.
**Main arguments:**
- `model`: List of callable model (e.g. neural network) layers.
- `activities`: List of activities for each layer free to vary.
- `output`: Observation or target of the generative model.
- `input`: Optional prior of the generative model.
- `y`: Observation or target of the generative model.
- `x`: Optional prior of the generative model.
**Other arguments:**
Expand Down
1 change: 1 addition & 0 deletions jpc/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def loop_body(state):
energies_iters = energies_iters.at[:, t].set(energies)
return t + 1, energies_iters

# 4096 is the max number of steps set in diffrax
energies_iters = zeros((len(model), 4096))
_, energies_iters = jax.lax.while_loop(
lambda state: state[0] < t_max,
Expand Down

0 comments on commit f8ff3a3

Please sign in to comment.