Skip to content

Commit

Permalink
Fix hybrid pc
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Jun 17, 2024
1 parent 07b3c1c commit b5f914b
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 95 deletions.
80 changes: 8 additions & 72 deletions jpc/_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,65 +6,6 @@
from typing import Callable, Optional


def _energy_fn(
generator: PyTree[Callable],
activities: PyTree[ArrayLike],
output: ArrayLike,
input: Optional[ArrayLike] = None,
amortiser: Optional[PyTree[Callable]] = None
) -> Scalar:
"""Computes the free energy for a 'hybrid' or standard predictive coding network.
**Main arguments:**
- `generator`: List of callable layers for the generative model.
- `activities`: List of activities for each layer free to vary.
- `output`: Observation or target of the generative model.
- `input`: Optional prior of the generative model.
**Other arguments:**
- `amortiser`: Optional list of callable layers for a network amortising
the inference of the generative model.
**Returns:**
The total energy normalised by batch size.
"""
batch_size = output.shape[0]
start_activity_l = 1 if input is not None else 2
n_activity_layers = len(activities)-1
n_layers = len(generator)-1

gen_eL = output - vmap(generator[-1])(activities[-2])
energy = 0.5 * sum(gen_eL ** 2)
if amortiser is not None:
amort_eL = input - vmap(amortiser[-1])(activities[0])
energy += 0.5 * sum(amort_eL ** 2)

for act_l, gen_l, amort_l in zip(
range(start_activity_l, n_activity_layers),
range(1, n_layers),
reversed(range(1, n_layers))
):
gen_err = activities[act_l] - vmap(generator[gen_l])(activities[act_l-1])
energy += 0.5 * sum(gen_err ** 2)
if amortiser is not None:
amort_err = activities[amort_l-1] - vmap(amortiser[gen_l])(activities[amort_l])
energy += 0.5 * sum(amort_err ** 2)

gen_e1 = activities[0] - vmap(generator[0])(input) if (
input is not None
) else activities[1] - vmap(generator[0])(activities[0])
energy += 0.5 * sum(gen_e1 ** 2)
if amortiser is not None:
amort_e1 = activities[-1] - vmap(amortiser[0])(output)
energy += 0.5 * sum(amort_e1 ** 2)

return energy / batch_size


def pc_energy_fn(
network: PyTree[Callable],
activities: PyTree[ArrayLike],
Expand Down Expand Up @@ -102,8 +43,8 @@ def pc_energy_fn(
"""
batch_size = output.shape[0]
start_activity_l = 1 if input is not None else 2
n_activity_layers = len(activities) - 1
n_layers = len(network) - 1
n_activity_layers = len(activities)-1
n_layers = len(network)-1

gen_eL = output - vmap(network[-1])(activities[-2])
energy = 0.5 * sum(gen_eL ** 2)
Expand All @@ -112,7 +53,7 @@ def pc_energy_fn(
range(start_activity_l, n_activity_layers),
range(1, n_layers)
):
gen_err = activities[act_l] - vmap(network[gen_l])(activities[act_l - 1])
gen_err = activities[act_l] - vmap(network[gen_l])(activities[act_l-1])
energy += 0.5 * sum(gen_err ** 2)

gen_e1 = activities[0] - vmap(network[0])(input) if (
Expand All @@ -125,7 +66,6 @@ def pc_energy_fn(

def hpc_energy_fn(
amortiser: PyTree[Callable],
generator: PyTree[Callable],
activities: PyTree[ArrayLike],
output: ArrayLike,
input: ArrayLike
Expand Down Expand Up @@ -155,7 +95,6 @@ def hpc_energy_fn(
- `amortiser`: List of callable layers for network amortising the inference
of the generative model.
- `generator`: List of callable layers for the generative model.
- `activities`: List of activities for each layer free to vary.
- `output`: Observation of the generative model (or input of the amortiser).
- `input`: Prior of the generative model (or output of the amortiser).
Expand All @@ -166,20 +105,17 @@ def hpc_energy_fn(
"""
batch_size = output.shape[0]
n_hidden = len(generator) - 1
n_hidden = len(amortiser) - 1

gen_eL = output - vmap(generator[-1])(activities[-2])
amort_eL = input - vmap(amortiser[-1])(activities[0])
energy = 0.5 * sum(amort_eL ** 2) + 0.5 * sum(gen_eL ** 2)
energy = 0.5 * sum(amort_eL ** 2)

for l, rev_l in zip(range(1, n_hidden), reversed(range(1, n_hidden))):
gen_err = activities[l] - vmap(generator[l])(activities[l-1])
amort_err = activities[rev_l-1] - vmap(amortiser[l])(activities[rev_l])
energy += 0.5 * sum(gen_err ** 2) + 0.5 * sum(amort_err ** 2)
energy += 0.5 * sum(amort_err ** 2)

gen_e1 = activities[0] - vmap(generator[0])(input)
amort_e1 = activities[-1] - vmap(amortiser[0])(output)
energy += 0.5 * sum(amort_e1 ** 2) + 0.5 * sum(gen_e1 ** 2)
amort_e1 = activities[-2] - vmap(amortiser[0])(output)
energy += 0.5 * sum(amort_e1 ** 2)

return energy / batch_size

Expand Down
26 changes: 9 additions & 17 deletions jpc/_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from equinox import filter_grad
from jaxtyping import PyTree, ArrayLike, Array
from typing import Union, Tuple, Callable, Optional
from ._energies import _energy_fn, pc_energy_fn, _lateral_energy_fn
from ._energies import pc_energy_fn, hpc_energy_fn, _lateral_energy_fn


def _neg_activity_grad(
t: Union[float, int],
activities: PyTree[ArrayLike],
args: Tuple[PyTree[Callable], Optional[PyTree[Callable]], ArrayLike, ArrayLike]
args: Tuple[Optional[PyTree[Callable]], ArrayLike, ArrayLike]
) -> PyTree[Array]:
"""Computes the negative gradient of the energy with respect to the activities.
Expand All @@ -36,13 +36,12 @@ def _neg_activity_grad(
List of negative gradients of the energy w.r.t the activities.
"""
amortiser, generator, output, input = args
dFdzs = grad(_energy_fn, argnums=1)(
generator, output, input = args
dFdzs = grad(pc_energy_fn, argnums=1)(
generator,
activities,
output,
input,
amortiser
input
)
return [-dFdz for dFdz in dFdzs]

Expand Down Expand Up @@ -109,7 +108,6 @@ def compute_pc_param_grads(


def compute_gen_param_grads(
amortiser: PyTree[Callable],
generator: PyTree[Callable],
activities: PyTree[ArrayLike],
output: ArrayLike,
Expand All @@ -125,8 +123,6 @@ def compute_gen_param_grads(
**Main arguments:**
- `amortiser`: List of callable layers for the network amortising the
inference of the generative model.
- `generator`: List of callable layers for the generative model.
- `activities`: List of activities for each layer free to vary.
- `output`: Observation or target of the generative model.
Expand All @@ -137,9 +133,8 @@ def compute_gen_param_grads(
List of parameter gradients for each layer of the generative network.
"""
return filter_grad(_energy_fn)(
return filter_grad(pc_energy_fn)(
generator,
amortiser,
activities,
output,
input
Expand All @@ -148,7 +143,6 @@ def compute_gen_param_grads(

def compute_amort_param_grads(
amortiser: PyTree[Callable],
generator: PyTree[Callable],
activities: PyTree[ArrayLike],
output: ArrayLike,
input: ArrayLike
Expand All @@ -159,7 +153,6 @@ def compute_amort_param_grads(
- `amortiser`: List of callable layers for a network amortising the
inference of the generative model.
- `generator`: List of callable layers for the generative network.
- `activities`: List of activities for each layer free to vary.
- `output`: Observation or target of the generative model.
- `input`: Optional prior of the generative model.
Expand All @@ -169,12 +162,11 @@ def compute_amort_param_grads(
List of parameter gradients for each layer of the amortiser.
"""
return filter_grad(_energy_fn, argnum=4)(
generator,
return filter_grad(hpc_energy_fn)(
amortiser,
activities,
output,
input,
amortiser
input
)


Expand Down
9 changes: 3 additions & 6 deletions jpc/_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@


def solve_pc_activities(
generator: PyTree[Callable],
network: PyTree[Callable],
activities: PyTree[ArrayLike],
output: ArrayLike,
input: Optional[ArrayLike] = None,
amortiser: Optional[PyTree[Callable]] = None,
solver: AbstractSolver = Dopri5(),
n_iters: int = 300,
stepsize_controller: AbstractStepSizeController = PIDController(
Expand All @@ -44,15 +43,13 @@ def solve_pc_activities(
**Main arguments:**
- `generator`: List of callable layers for the generative model.
- `network`: List of callable layers for the generative model.
- `activities`: List of activities for each layer free to vary.
- `output`: Observation or target of the generative model.
- `input`: Optional prior of the generative model.
**Other arguments:**
- `amortiser`: Optional list of callable layers for a network amortising
the inference of the generative model.
- `solver`: Diffrax (ODE) solver to be used. Default is Dopri5.
- `n_iters`: Number of integration steps (300 as default).
- `stepsize_controller`: diffrax controllers for step size integration.
Expand All @@ -74,7 +71,7 @@ def solve_pc_activities(
t1=n_iters,
dt0=dt,
y0=activities,
args=(amortiser, generator, output, input),
args=(network, output, input),
stepsize_controller=stepsize_controller,
saveat=SaveAt(t1=True, steps=record_iters)
)
Expand Down
30 changes: 30 additions & 0 deletions jpc/_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,33 @@ def init_activities_from_gaussian(
)
)
return activities


def amort_init(
    amortiser: PyTree[Callable],
    generator: PyTree[Callable],
    output: ArrayLike
) -> PyTree[Array]:
    """Initialises layers' activity using an amortised network.

    The amortiser is run layer by layer starting from `output`; the resulting
    activities are then reversed, and the prediction of the last generator
    layer (computed from the activity produced by the amortiser's first
    layer) is appended at the end.

    **Main arguments:**

    - `amortiser`: List of callable layers for network amortising the inference
        of the generative model. Must contain at least one layer.
    - `generator`: List of callable layers for the generative model. Only its
        last layer is used here.
    - `output`: Input to the amortiser.

    **Returns:**

    List with amortised initialisation of each layer, of length
    `len(amortiser) + 1`.

    **Raises:**

    - `ValueError`: if `amortiser` has no layers.

    """
    if len(amortiser) == 0:
        raise ValueError("`amortiser` must contain at least one layer.")

    # Forward sweep through the amortiser, each layer fed by the previous one.
    forward_pass = [vmap(amortiser[0])(output)]
    for layer_idx in range(1, len(amortiser)):
        forward_pass.append(
            vmap(amortiser[layer_idx])(forward_pass[layer_idx - 1])
        )

    # Reverse so the amortiser's first-layer activity sits last, then append
    # the generator's prediction computed from that activity.
    activities = forward_pass[::-1]
    activities.append(
        vmap(generator[-1])(activities[-1])
    )
    return activities

0 comments on commit b5f914b

Please sign in to comment.