Skip to content

Commit

Permalink
Add functionality to make_pc_step for unsupervised training
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Jun 25, 2024
1 parent e27013c commit fffafd0
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 16 deletions.
8 changes: 3 additions & 5 deletions jpc/_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ def test_generative_pc(
- `key`: `jax.random.PRNGKey` for random initialisation of activities.
- `layer_sizes`: Dimension of all layers (input, hidden and output).
- `batch_size`: Dimension of data batch for random initialisation of
activities.
- `batch_size`: Dimension of data batch for activity initialisation.
- `network`: List of callable network layers.
- `output`: Observation or target of the generative model.
- `input`: Optional prior of the generative model.
Expand Down Expand Up @@ -132,9 +131,8 @@ def test_hpc(
- `key`: `jax.random.PRNGKey` for random initialisation of activities.
- `layer_sizes`: Dimension of all layers (input, hidden and output).
- `batch_size`: Dimension of data batch for random initialisation of
activities.
- `generator`: List of callable layers for the generative network..
- `batch_size`: Dimension of data batch for initialisation of activities.
- `generator`: List of callable layers for the generative network.
- `amortiser`: List of callable layers for network amortising the inference
of the generative model.
- `output`: Observation or target of the generative model.
Expand Down
53 changes: 42 additions & 11 deletions jpc/_train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""High-level API to train neural networks with predictive coding."""

import equinox as eqx
from jax import vmap, lax
from jax import vmap
from jax.tree_util import tree_map
from jax.numpy import mean, array
from diffrax import (
Expand All @@ -12,13 +12,14 @@
)
from jpc import (
init_activities_with_ffwd,
init_activities_from_gaussian,
init_activities_with_amort,
solve_pc_activities,
compute_pc_param_grads,
get_t_max
)
from optax import GradientTransformationExtraArgs, OptState
from jaxtyping import PyTree, ArrayLike, Scalar, Array
from jaxtyping import PyTree, ArrayLike, Scalar, Array, PRNGKeyArray
from typing import Callable, Optional, Tuple


Expand All @@ -33,7 +34,11 @@ def make_pc_step(
dt: float | int = 1,
n_iters: Optional[int] = 20,
stepsize_controller: AbstractStepSizeController = ConstantStepSize(),
record_activities: bool = False
record_activities: bool = False,
key: Optional[PRNGKeyArray] = None,
layer_sizes: Optional[PyTree[int]] = None,
batch_size: Optional[int] = None,
sigma: Scalar = 0.05,
) -> Tuple[
PyTree[Callable],
GradientTransformationExtraArgs,
Expand All @@ -52,6 +57,12 @@ def make_pc_step(
- `output`: Observation or target of the generative model.
- `input`: Optional prior of the generative model.
!!! note
The arguments `key`, `layer_sizes` and `batch_size` must be passed if
`input` is `None`: unsupervised training is then assumed, and the
activities are initialised randomly.
**Other arguments:**
- `solver`: Diffrax (ODE) solver to be used. Default is Euler.
Expand All @@ -61,15 +72,36 @@ def make_pc_step(
Defaults to `ConstantStepSize`.
- `record_activities`: If `True`, returns activities at every inference
iteration.
- `key`: `jax.random.PRNGKey` for random initialisation of activities.
- `layer_sizes`: Dimension of all layers (input, hidden and output).
- `batch_size`: Dimension of data batch for activity initialisation.
- `sigma`: Standard deviation of the Gaussian from which activities are
sampled for random initialisation. Defaults to 5e-2.
**Returns:**
Network with updated weights, optimiser, optimiser state, training loss,
equilibrated activities and last inference step.
"""
activities = init_activities_with_ffwd(network=network, input=input)
train_mse_loss = mean((output - activities[-1])**2)
if input is None and any(x is None for x in (key, layer_sizes, batch_size)):
raise ValueError("""
If there is no input, then unsupervised training is assumed, and
`key`, `layer_sizes` and `batch_size` must be passed for random
initialisation of activities.
""")
if input is None:
activities = init_activities_from_gaussian(
key=key,
layer_sizes=layer_sizes,
mode="unsupervised",
batch_size=batch_size,
sigma=sigma
)
else:
activities = init_activities_with_ffwd(network=network, input=input)

train_mse_loss = mean((output - activities[-1])**2) if input is not None else None
equilib_activities = solve_pc_activities(
network=network,
activities=activities,
Expand All @@ -81,14 +113,13 @@ def make_pc_step(
dt=dt,
record_iters=record_activities
)
t_max = lax.cond(
record_activities,
lambda: get_t_max(equilib_activities),
lambda: array(0)
)
t_max = get_t_max(equilib_activities)
param_grads = compute_pc_param_grads(
network=network,
activities=tree_map(lambda act: act[t_max], equilib_activities),
activities=tree_map(
lambda act: act[t_max if record_activities else array(0)],
equilib_activities
),
output=output,
input=input
)
Expand Down

0 comments on commit fffafd0

Please sign in to comment.