Skip to content

Commit

Permalink
Jit compute_pc_infer_energies and add comment on unsupervised training
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Jun 21, 2024
1 parent 4ce2537 commit d8bd2ab
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/generative_pc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@
"\n",
"A PC network can be trained in a single line of code with `jpc.make_pc_step()`. See the documentation for more. Similarly, we can use `jpc.test_generative_pc()` to get some test metrics including accuracy of inferred labels and image predictions. Note that these functions are already \"jitted\" for performance.\n",
"\n",
"Below we simply wrap each of these functions in our training and test loops, respectively."
"Below we simply wrap each of these functions in our training and test loops, respectively. Note that to train in an unsupervised way, you can simply remove the `input` from `jpc.make_pc_step()` and the `evaluate()` script."
]
},
{
Expand Down Expand Up @@ -341,7 +341,7 @@
" optim=optim,\n",
" opt_state=opt_state,\n",
" output=img_batch,\n",
" input=label_batch,\n",
" #input=label_batch,\n",
" n_iters=n_infer_iters\n",
" )\n",
" if ((iter+1) % test_every) == 0:\n",
Expand Down Expand Up @@ -441,4 +441,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
2 changes: 2 additions & 0 deletions jpc/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax
from jax.numpy import tanh, mean, argmax
import equinox as eqx
import equinox.nn as nn
from jpc import pc_energy_fn
from jaxtyping import PRNGKeyArray, PyTree, ArrayLike, Scalar, Array
Expand Down Expand Up @@ -82,6 +83,7 @@ def get_t_max(activities_iters: PyTree[Array]) -> int:
return int(t_max)


@eqx.filter_jit
def compute_pc_infer_energies(
network: PyTree[Callable],
activities_iters: PyTree[Array],
Expand Down

0 comments on commit d8bd2ab

Please sign in to comment.