Skip to content

Commit

Permalink
Add functionality to compute energies during inference
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Jun 25, 2024
1 parent 6e4f94c commit e27013c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 50 deletions.
2 changes: 1 addition & 1 deletion jpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
get_fc_network as get_fc_network,
compute_accuracy as compute_accuracy,
get_t_max as get_t_max,
compute_pc_infer_energies as compute_pc_infer_energies
compute_infer_energies as compute_infer_energies
)
from ._train import (
make_pc_step as make_pc_step,
Expand Down
13 changes: 7 additions & 6 deletions jpc/_core/_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from jax.numpy import sum, array
from jax import vmap
from jax.tree_util import tree_map
from jaxtyping import PyTree, ArrayLike, Scalar
from typing import Callable, Optional, Union

Expand Down Expand Up @@ -49,25 +50,25 @@ def pc_energy_fn(
"""
batch_size = output.shape[0]
start_activity_l = 1 if input is not None else 2
n_activity_layers = len(activities)-1
n_layers = len(network)-1
n_activity_layers = len(activities) - 1
n_layers = len(network) - 1

eL = output - vmap(network[-1])(activities[-2])
energies = [0.5 * sum(eL ** 2)]
energies = [sum(eL ** 2)]

for act_l, net_l in zip(
range(start_activity_l, n_activity_layers),
range(1, n_layers)
):
err = activities[act_l] - vmap(network[net_l])(activities[act_l-1])
energies.append(0.5 * sum(err ** 2))
energies.append(sum(err ** 2))

e1 = activities[0] - vmap(network[0])(input) if (
input is not None
) else activities[1] - vmap(network[0])(activities[0])
energies.append(0.5 * sum(e1 ** 2))
energies.append(sum(e1 ** 2))

if record_layers:
return [energy_l / batch_size for energy_l in energies]
return tree_map(lambda energy: energy / batch_size, energies)
else:
return sum(array(energies)) / batch_size
67 changes: 30 additions & 37 deletions jpc/_train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""High-level API to train neural networks with predictive coding."""

import equinox as eqx
from jax import vmap
from jax.numpy import mean
from jax import vmap, lax
from jax.tree_util import tree_map
from jax.numpy import mean, array
from diffrax import (
AbstractSolver,
AbstractStepSizeController,
Expand All @@ -13,11 +14,12 @@
init_activities_with_ffwd,
init_activities_with_amort,
solve_pc_activities,
compute_pc_param_grads
compute_pc_param_grads,
get_t_max
)
from optax import GradientTransformationExtraArgs, OptState
from jaxtyping import PyTree, ArrayLike, Scalar, Array
from typing import Callable, Optional, Union, Tuple
from typing import Callable, Optional, Tuple


@eqx.filter_jit
Expand All @@ -32,21 +34,14 @@ def make_pc_step(
n_iters: Optional[int] = 20,
stepsize_controller: AbstractStepSizeController = ConstantStepSize(),
record_activities: bool = False
) -> Union[
Tuple[
PyTree[Callable],
GradientTransformationExtraArgs,
OptState,
Scalar,
PyTree[Array]
],
Tuple[
PyTree[Callable],
GradientTransformationExtraArgs,
OptState,
Scalar
]
]:
) -> Tuple[
PyTree[Callable],
GradientTransformationExtraArgs,
OptState,
Scalar,
PyTree[Array],
Array
]:
"""Updates network parameters with predictive coding.
**Main arguments:**
Expand All @@ -69,8 +64,8 @@ def make_pc_step(
**Returns:**
Network with updated weights, optimiser, optimiser state, training loss and
optionally activities during inference.
Network with updated weights, optimiser, optimiser state, training loss,
equilibrated activities and last inference step.
"""
activities = init_activities_with_ffwd(network=network, input=input)
Expand All @@ -86,9 +81,14 @@ def make_pc_step(
dt=dt,
record_iters=record_activities
)
t_max = lax.cond(
record_activities,
lambda: get_t_max(equilib_activities),
lambda: array(0)
)
param_grads = compute_pc_param_grads(
network=network,
activities=[act[-1] for act in equilib_activities],
activities=tree_map(lambda act: act[t_max], equilib_activities),
output=output,
input=input
)
Expand All @@ -98,21 +98,14 @@ def make_pc_step(
params=network
)
network = eqx.apply_updates(model=network, updates=updates)
if record_activities:
return (
network,
optim,
opt_state,
train_mse_loss,
equilib_activities
)
else:
return (
network,
optim,
opt_state,
train_mse_loss
)
return (
network,
optim,
opt_state,
train_mse_loss,
equilib_activities,
t_max
)


@eqx.filter_jit
Expand Down
12 changes: 6 additions & 6 deletions jpc/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax
from jax.numpy import tanh, mean, argmax
from jax.tree_util import tree_map
import equinox as eqx
import equinox.nn as nn
from jpc import pc_energy_fn
Expand Down Expand Up @@ -78,13 +79,12 @@ def compute_accuracy(truths: ArrayLike, preds: ArrayLike) -> Scalar:
)


def get_t_max(activities_iters: PyTree[Array]) -> int:
t_max = argmax(activities_iters[0][:, 0, 0])-1
return int(t_max)
def get_t_max(activities_iters: PyTree[Array]) -> Array:
    """Return the index of the last valid inference step.

    Looks at the recorded trace of a single scalar activity (first layer,
    first batch element, first unit) across all saved solver iterations.

    NOTE(review): the `argmax(...) - 1` convention presumably relies on the
    ODE solver padding unused time slots with `inf` after the final step,
    so the first padded entry marks one past the last real iteration —
    confirm against the diffrax solution layout.
    """
    # Trace of one representative activity over the time/iteration axis.
    scalar_trace = activities_iters[0][:, 0, 0]
    # Index just before the trace's maximum (first inf pad, when present).
    return argmax(scalar_trace) - 1


@eqx.filter_jit
def compute_pc_infer_energies(
def compute_infer_energies(
network: PyTree[Callable],
activities_iters: PyTree[Array],
t_max: int,
Expand Down Expand Up @@ -119,12 +119,12 @@ def compute_pc_infer_energies(
if t % compute_every == 0:
energies = pc_energy_fn(
network=network,
activities=[act[t] for act in activities_iters],
activities=tree_map(lambda act: act[t], activities_iters),
output=output,
input=input,
record_layers=True
)
for l in range(len(network)):
energies_iters[l].append(energies[l])

return energies_iters
return energies_iters[::-1]

0 comments on commit e27013c

Please sign in to comment.