Skip to content

Commit

Permalink
Remove lateral pc as experimental
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Jun 17, 2024
1 parent b5f914b commit bf1d344
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 169 deletions.
53 changes: 1 addition & 52 deletions jpc/_energies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Energy functions for predictive coding networks."""

from jax.numpy import sum, log, exp
from jax.numpy import sum
from jax import vmap
from jaxtyping import PyTree, ArrayLike, Scalar
from typing import Callable, Optional
Expand Down Expand Up @@ -118,54 +118,3 @@ def hpc_energy_fn(
energy += 0.5 * sum(amort_e1 ** 2)

return energy / batch_size


def _lateral_energy_fn(
amortiser: PyTree[Callable],
activities: PyTree[ArrayLike],
outputs: PyTree[ArrayLike],
) -> Scalar:
"""Computes the free energy for a predictive coding network with lateral connections.
!!! note
This is currently experimental.
**Main arguments:**
- `amortiser`: List of callable layers for an amortised network.
- `activities`: List of activities for each layer free to vary, one list
per branch (n=2).
- `outputs`: List of two inputs to the amortiser, one for each branch.
**Returns:**
The total energy normalised by batch size.
"""
activities1, activities2 = activities
output1, output2 = outputs
batch_size = output1.shape[0]
n_layers = len(amortiser)

amort_e1 = activities1[-1] - vmap(amortiser[0])(output1)
amort_e12 = activities2[-1] - vmap(amortiser[0])(output2)
energy = 0.5 * sum(amort_e1 ** 2) + 0.5 * sum(amort_e12 ** 2)

lateral1 = activities1[-1] - activities2[-1]
lateralL = activities1[0] - activities2[0]
energy += 0.5 * sum(lateral1 ** 2) + 0.5 * sum(lateralL ** 2)

for l, rev_l in zip(range(1, n_layers), reversed(range(1, n_layers))):
amort_err = activities1[rev_l-1] - vmap(amortiser[l])(activities1[rev_l])
amort_err2 = activities2[rev_l-1] - vmap(amortiser[l])(activities2[rev_l])
energy += 0.5 * sum(amort_err ** 2) + 0.5 * sum(amort_err2 ** 2)

lateral_err = activities2[l] - activities2[l]
energy += 0.5 * sum(lateral_err ** 2)

logsumexp = log(sum(exp(-amort_err**2), axis=0))
logsumexp2 = log(sum(exp(-amort_err2**2), axis=0))
energy += 0.5 * sum(logsumexp) + 0.5 * sum(logsumexp2)

return energy / batch_size
61 changes: 1 addition & 60 deletions jpc/_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from equinox import filter_grad
from jaxtyping import PyTree, ArrayLike, Array
from typing import Union, Tuple, Callable, Optional
from ._energies import pc_energy_fn, hpc_energy_fn, _lateral_energy_fn
from ._energies import pc_energy_fn, hpc_energy_fn


def _neg_activity_grad(
Expand Down Expand Up @@ -46,39 +46,6 @@ def _neg_activity_grad(
return [-dFdz for dFdz in dFdzs]


def _neg_lateral_activity_grad(
t: Union[float, int],
activities: PyTree[ArrayLike],
args: Tuple[PyTree[Callable], PyTree[ArrayLike], PyTree[ArrayLike]]
) -> PyTree[Array]:
"""Same as `_neg_activity_grad` but for a network with lateral connections.
**Main arguments:**
- `t`: Time step of the ODE system, used for downstream integration by
`diffrax.diffeqsolve`.
- `activities`: List of activities for each layer free to vary, one list
per branch (n=2).
- `args`: 2-Tuple with
(i) list of callable layers for amortised network, and
(ii) network outpus (observations), one for each branch.
**Returns:**
List of negative gradients of the energy w.r.t the activities.
"""
amortiser, outputs = args
dFdzs = grad(_lateral_energy_fn, argnums=1)(
amortiser,
activities,
outputs
)
for branch in range(2):
dFdzs[branch] = [-dFdz for dFdz in dFdzs[branch]]
return dFdzs


def compute_pc_param_grads(
network: PyTree[Callable],
activities: PyTree[ArrayLike],
Expand Down Expand Up @@ -168,29 +135,3 @@ def compute_amort_param_grads(
output,
input
)


def compute_lateral_pc_param_grads(
amortiser: PyTree[Callable],
activities: PyTree[ArrayLike],
outputs: PyTree[ArrayLike],
) -> PyTree[Array]:
"""Same as `compute_pc_param_grads` but for a network with lateral connections.
**Main arguments:**
- `amortiser`: List of callable layers for an amortised network.
- `activities`: List of activities for each layer free to vary, one list
per branch (n=2).
- `outputs`: List of two inputs to the amortiser, one for each branch.
**Returns:**
List of parameter gradients for each network layer.
"""
return filter_grad(_lateral_energy_fn)(
amortiser,
activities,
outputs
)
60 changes: 3 additions & 57 deletions jpc/_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from jaxtyping import PyTree, ArrayLike, Array
from typing import Callable, Optional, Union
from ._grads import _neg_activity_grad, _neg_lateral_activity_grad
from ._grads import _neg_activity_grad
from diffrax import (
AbstractSolver,
AbstractStepSizeController,
Expand All @@ -22,8 +22,8 @@ def solve_pc_activities(
solver: AbstractSolver = Dopri5(),
n_iters: int = 300,
stepsize_controller: AbstractStepSizeController = PIDController(
rtol=1e-5,
atol=1e-5
rtol=1e-3,
atol=1e-3
),
dt: Union[float, int] = None,
record_iters: bool = False
Expand Down Expand Up @@ -76,57 +76,3 @@ def solve_pc_activities(
saveat=SaveAt(t1=True, steps=record_iters)
)
return sol.ys if record_iters else [activity[0] for activity in sol.ys]


def solve_lateral_pc_activities(
amortiser: PyTree[Callable],
activities: PyTree[ArrayLike],
outputs: PyTree[ArrayLike],
solver: AbstractSolver = Dopri5(),
n_iters: int = 300,
stepsize_controller: AbstractStepSizeController = PIDController(
rtol=1e-5,
atol=1e-5
),
dt: Union[float, int] = None,
record_iters: bool = False
) -> PyTree[Array]:
"""Same as `solve_pc_activities` but for a network with lateral connections.
**Main arguments:**
- `amortiser`: List of callable layers for an amortised network.
- `activities`: List of activities for each layer free to vary, one list
per branch (n=2).
- `outputs`: List of two inputs to the amortiser, one for each branch.
**Other arguments:**
- `solver`: Diffrax (ODE) solver to be used. Default is Dopri5.
- `n_iters`: Number of integration steps (300 as default).
- `stepsize_controller`: diffrax controllers for step size integration.
Defaults to `PIDController`.
- `dt`: Integration step size. Defaults to None, since step size is
automatically determined by the default `PIDController`.
- `record_iters`: If `True`, returns all integration steps. `False` by
default.
**Returns:**
List with solution of the activity dynamics for each branch and layer.
"""
sol = diffeqsolve(
terms=ODETerm(_neg_lateral_activity_grad),
solver=solver,
t0=0,
t1=n_iters,
dt0=dt,
y0=activities,
args=(amortiser, outputs),
stepsize_controller=stepsize_controller,
saveat=SaveAt(t1=True, steps=record_iters)
)
return sol.ys if record_iters else [
[activity[0] for activity in branch] for branch in sol.ys
]

0 comments on commit bf1d344

Please sign in to comment.