Skip to content

Commit

Permalink
Restructure package
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Jun 18, 2024
1 parent 8b42cfa commit 243fe37
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 15 deletions.
15 changes: 10 additions & 5 deletions jpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
import importlib.metadata

from ._utils import get_fc_network as get_fc_network
from ._init import (
from core._init import (
init_activities_with_ffwd as init_activities_with_ffwd,
init_activities_from_gaussian as init_activities_from_gaussian,
amort_init as amort_init
)
from ._energies import (
from core._energies import (
pc_energy_fn as pc_energy_fn,
hpc_energy_fn as hpc_energy_fn
)
from ._infer import solve_pc_activities as solve_pc_activities
from ._grads import (
from core._infer import solve_pc_activities as solve_pc_activities
from core._grads import (
compute_pc_param_grads as compute_pc_param_grads,
compute_gen_param_grads as compute_gen_param_grads,
compute_amort_param_grads as compute_amort_param_grads
)

from ._utils import get_fc_network as get_fc_network
from ._train import (
make_pc_step as make_pc_step,
make_hpc_step as make_hpc_step
)
from ._test import (
test_generative_pc as test_generative_pc,
test_hpc as test_hpc
)


__version__ = importlib.metadata.version("jpc")
22 changes: 12 additions & 10 deletions jpc/_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
Dopri5,
PIDController
)
from optax import GradientTransformationExtraArgs, OptState
from ._init import init_activities_with_ffwd, amort_init
from ._infer import solve_pc_activities
from ._grads import (
from jpc import (
init_activities_with_ffwd,
amort_init,
solve_pc_activities,
compute_pc_param_grads,
compute_gen_param_grads,
compute_amort_param_grads
)

from optax import GradientTransformationExtraArgs, OptState
from jaxtyping import PyTree, ArrayLike, Scalar
from typing import Callable, Optional, Union, Tuple

Expand Down Expand Up @@ -49,8 +49,8 @@ def make_pc_step(
**Other arguments:**
- `solver`: Diffrax (ODE) solver to be used. Default is Dopri5.
- `n_iters`: Number of integration steps (300 as default).
- `stepsize_controller`: diffrax controllers for step size integration.
- `n_iters`: Number of integration steps for inference (300 as default).
- `stepsize_controller`: diffrax controllers for inference integration.
Defaults to `PIDController`.
- `dt`: Integration step size. Defaults to None, since step size is
automatically determined by the default `PIDController`.
Expand Down Expand Up @@ -133,9 +133,11 @@ def make_hpc_step(
**Main arguments:**
- `network`: List of callable network layers.
- `optim`: Optax optimiser, e.g. `optax.sgd()`.
- `opt_state`: State of Optax optimiser.
- `generator`: List of callable layers for the generative network.
- `amortiser`: List of callable layers for network amortising the inference
of the generative model.
- `optims`: Optax optimisers (e.g. `optax.sgd()`), one for each model.
- `opt_states`: State of Optax optimisers, one for each model.
- `output`: Observation of the generator, input to the amortiser.
- `input`: Prior of the generator, target for the amortiser.
Expand Down
Empty file added jpc/core/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
1 change: 1 addition & 0 deletions jpc/_init.py → jpc/core/_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def init_activities_from_gaussian(
'unsupervised' the input layer is also initialised.
- `batch_size`: Dimension of data batch.
- `sigma`: Standard deviation for Gaussian to sample activities from.
Defaults to 5e-2.
**Returns:**
Expand Down

0 comments on commit 243fe37

Please sign in to comment.