Skip to content

Commit

Permalink
Change logo depending on light vs dark mode & update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Jun 29, 2024
1 parent 334c9ae commit 2b34e0d
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 35 deletions.
2 changes: 2 additions & 0 deletions docs/_overrides/partials/logo.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
<img id="logo_light_mode" src="{{ config.theme.logo_light_mode | url }}" alt="logo">
<img id="logo_dark_mode" src="{{ config.theme.logo_dark_mode | url }}" alt="logo">
30 changes: 30 additions & 0 deletions docs/_static/custom_css.css
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,36 @@ html {
scroll-padding-top: 50px;
}

/* Hide the dark logo by default */
#logo_dark_mode {
display: none;
}

/* Show the light logo by default */
#logo_light_mode {
display: block;
}

/* Switch display property based on color scheme */
[data-md-color-scheme="default"] {
--md-footer-logo-dark-mode: none;
--md-footer-logo-light-mode: block;
}

[data-md-color-scheme="slate"] {
--md-footer-logo-dark-mode: block;
--md-footer-logo-light-mode: none;
}

/* Apply the custom variables */
#logo_light_mode {
display: var(--md-footer-logo-light-mode);
}

#logo_dark_mode {
display: var(--md-footer-logo-dark-mode);
}

/* Adjust logo size */
.md-header__button.md-logo {
margin: 0;
Expand Down
3 changes: 3 additions & 0 deletions docs/_static/logo-dark.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
File renamed without changes
Empty file added docs/advanced_usage.md
Empty file.
3 changes: 0 additions & 3 deletions docs/api/get_fc_network.md

This file was deleted.

3 changes: 3 additions & 0 deletions docs/api/make_mlp.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# make_mlp

::: jpc.make_mlp
64 changes: 64 additions & 0 deletions docs/basic_usage.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
!!! info
JPC provides two types of API depending on the use case:
* a simple, basic API that allows to train and test models with predictive
coding with a few lines of code
* a more advanced and flexible API allowing for

Describe purposes/use cases of both basic and advanced.

# Basic usage

JPC provides a single convenience function `jpc.make_pc_step()` to train
predictive coding networks (PCNs) on classification and generation tasks, in a
supervised as well as unsupervised manner.
```py
import jpc

relu_net = jpc.get_fc_network(key, [10, 100, 100, 10], "relu")
result = jpc.make_pc_step(
model=relu_net,
optim=optim,
opt_state=opt_state,
y=y,
x=x
)
```
At a minimum, `jpc.make_pc_step()` takes a model, an optax optimiser and its
state, and an output target. Under the hood, `jpc.make_pc_step()` uses diffrax
to solve the activity (inference) dynamics of PC. The arguments can be changed
```py
import jpc

result = jpc.make_pc_step(
model=network,
optim=optim,
opt_state=opt_state,
y=y,
x=x,
solver=other_solver,
dt=1e-1,
)
```
Moreover,

JPC provides a similar function for training a hybrid PCN
```py
import jax
import jax.numpy as jnp
from equinox import nn as nn

# some data
x = jnp.array([1., 1., 1.])
y = -x

# network
key = jax.random.key(0)
_, *subkeys = jax.random.split(key)
network = [nn.Sequential(
[
nn.Linear(3, 100, key=subkeys[0]),
nn.Lambda(jax.nn.relu)],
),
nn.Linear(100, 3, key=subkeys[1]),
]
```
Empty file added docs/extending_jpc.md
Empty file.
68 changes: 41 additions & 27 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# Getting started

JPC is a [JAX](https://github.com/google/jax) library for predictive
coding networks (PCNs). It is built on top of two main libraries:
JPC is a [JAX](https://github.com/google/jax) library to train neural networks
with predictive coding. It is built on top of three main libraries:

* [Equinox](https://github.com/patrick-kidger/equinox), to define neural
networks with PyTorch-like syntax, and
* [Diffrax](https://github.com/patrick-kidger/diffrax), to solve the PC
activity (inference) dynamics.
* [Optax](https://github.com/google-deepmind/optax), for parameter optimisation.

JPC provides a simple but flexible API for research of PCNs compatible with
useful JAX transforms such as `vmap` and `jit`.

## Installation
## 💻 Installation

```
pip install jpc
Expand All @@ -22,10 +23,9 @@ Requires Python 3.9+, JAX 0.4.23+, [Equinox](https://github.com/patrick-kidger/e
[Optax](https://github.com/google-deepmind/optax) 0.2.2+, and
[Jaxtyping](https://github.com/patrick-kidger/jaxtyping) 0.2.24+.

## Quick example
## ⚡️ Quick example

Given a neural network with callable layers, for example defined with
[Equinox](https://github.com/patrick-kidger/equinox)
Given a neural network with callable layers
```py
import jax
import jax.numpy as jnp
Expand All @@ -38,35 +38,49 @@ y = -x
# network
key = jax.random.key(0)
_, *subkeys = jax.random.split(key)
network = [
nn.Sequential(
[
nn.Linear(3, 100, key=subkeys[0]),
nn.Lambda(jax.nn.relu)
],
network = [nn.Sequential(
[
nn.Linear(3, 100, key=subkeys[0]),
nn.Lambda(jax.nn.relu)],
),
nn.Linear(100, 3, key=subkeys[1]),
]
```
We can train it with predictive coding in a few lines of code
we can perform a PC parameter update with a single function call
```py
import jpc
import optax
import equinox as eqx

# initialise layer activities with a feedforward pass
activities = jpc.init_activities_with_ffwd(network, x)
# optimiser
optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(network, eqx.is_array))

# run the inference dynamics to equilibrium
equilib_activities = jpc.solve_pc_activities(network, activities, y, x)

# compute the PC parameter gradients
pc_param_grads = jpc.compute_pc_param_grads(
network,
equilib_activities,
y,
x
# PC parameter update
result = jpc.make_pc_step(
model=network,
optim=optim,
opt_state=opt_state,
y=y,
x=x
)

```
The gradients can then be fed to your favourite optimiser (e.g. gradient
descent) to update the network parameters.

## Citation
## 📄 Citation

If you found this library useful in your work, please cite (arXiv link):

```bibtex
@article{innocenti2024jpc,
title={JPC: Predictive Coding Networks in JAX},
author={Innocenti, Francesco and Kinghorn, Paul and Singh, Ryan and
De Llanza Varona, Miguel and Buckley, Christopher},
journal={arXiv preprint},
year={2024}
}
```
Also consider starring the project [on GitHub](https://github.com/thebuckleylab/jpc)! ⭐️

## ⏭️ Next steps

2 changes: 1 addition & 1 deletion jpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
compute_pc_param_grads as compute_pc_param_grads
)
from ._utils import (
get_fc_network as get_fc_network,
make_mlp as make_mlp,
compute_accuracy as compute_accuracy,
get_t_max as get_t_max,
compute_infer_energies as compute_infer_energies
Expand Down
11 changes: 7 additions & 4 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,24 @@ theme:
# Light mode / dark mode
# We deliberately don't automatically use `media` to check a user's preferences. We default to light mode as
# (a) it looks more professional, and (b) is more obvious about the fact that it offers a (dark mode) toggle.
- scheme: default
- media: "(prefers-color-scheme: default)"
scheme: default
primary: white
accent: orange
toggle:
icon: material/weather-night
name: Switch to dark mode
- scheme: slate
- media: "(prefers-color-scheme: slate)"
scheme: slate
primary: black
accent: orange
toggle:
icon: material/weather-sunny
name: Switch to light mode
icon:
repo: fontawesome/brands/github # GitHub logo in top right
logo: "_static/logo.svg" # jpc logo in top left
logo_light_mode: "_static/logo-light.svg" # jpc logo in top left
logo_dark_mode: "_static/logo-dark.svg" # jpc logo in top left
favicon: "_static/favicon.png"
custom_dir: "docs/_overrides" # Overriding part of the HTML

Expand Down Expand Up @@ -106,7 +109,7 @@ nav:
- 🌱 Basic API:
- 'api/Training.md'
- 'api/Testing.md'
- 'api/get_fc_network.md'
- 'api/make_mlp.md'
- 🚀 Advanced API:
- 'api/Initialisation.md'
- 'api/pc_energy_fn.md'
Expand Down

0 comments on commit 2b34e0d

Please sign in to comment.