Skip to content

Commit

Permalink
move to penzai V2 API and independent treescope package
Browse files Browse the repository at this point in the history
  • Loading branch information
felixzinn committed Aug 6, 2024
1 parent b58e088 commit bc89312
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 26 deletions.
18 changes: 9 additions & 9 deletions docs/building_blocks.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ Correlate a Parameter
```


:::{admonition} Inspect `evm.Parameters` with `penzai`
:::{admonition} Inspect `evm.Parameters` with `treescope`
:class: tip dropdown

Inspect a (PyTree of) `evm.Parameters` with [`penzai`'s treescope](https://penzai.readthedocs.io/en/stable/notebooks/treescope_prettyprinting.html) visualization in IPython or Colab notebooks (see <project:#penzai-visualization> for more information).
Inspect a (PyTree of) `evm.Parameters` with [treescope](https://treescope.readthedocs.io/en/stable/index.html) visualization in IPython or Colab notebooks (see <project:#treescope-visualization> for more information).
You can even add custom visualizers, such as:

```{code-block} python
Expand All @@ -134,9 +134,9 @@ import evermore as evm
tree = {"a": evm.NormalParameter(), "b": evm.NormalParameter()}
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
pz_tree = evm.visualization.convert_tree_to_penzai(tree)
pz.ts.display(pz_tree)
with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
tree = evm.visualization.convert_tree_to_treescope(tree)
treescope.display(tree)
```
:::

Expand Down Expand Up @@ -275,7 +275,7 @@ Multiple modifiers should be combined using `evm.modifier.Compose` or the `@` op
import jax
import jax.numpy as jnp
import evermore as evm
from penzai import pz
import treescope
jax.config.update("jax_enable_x64", True)
Expand All @@ -293,7 +293,7 @@ modifier2 = param.scale_log(up=1.1, down=0.9)
(modifier1 @ modifier2)(jnp.array([10, 20, 30]))
# -> Array([10.259877, 20.500944, 30.760822], dtype=float32)
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
pz_tree = evm.visualization.convert_tree_to_penzai(modifier1 @ modifier2)
pz.ts.display(pz_tree)
with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
tree = evm.visualization.convert_tree_to_treescope(modifier1 @ modifier2)
treescope.display(tree)
```
26 changes: 13 additions & 13 deletions docs/tips_and_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@ kernelspec:

Here are some advanced tips and tricks.

(penzai-visualization)=
## penzai Visualization
(treescope-visualization)=
## treescope Visualization

evermore components can be visualized with `penzai`. Convert the corresponding PyTree using `evermore.visualization.convert_tree_to_penzai` and use it with `penzai`. In IPython notebooks you can display the tree using `penzai.ts.display`.
evermore components can be visualized with `treescope`. Convert the corresponding PyTree using `evermore.visualization.convert_tree_to_treescope` and use it with `treescope`. In IPython notebooks you can display the tree using `treescope.display`.

```{code-cell} ipython3
import jax
import jax.numpy as jnp
import evermore as evm
import equinox as eqx
from penzai import pz
import treescope
jax.config.update("jax_enable_x64", True)
Expand All @@ -50,16 +50,16 @@ composition = evm.modifier.Compose(
evm.Modifier(parameter=sigma1, effect=evm.effect.AsymmetricExponential(up=1.2, down=0.8)),
)
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
pz_tree = evm.visualization.convert_tree_to_penzai(composition)
pz.ts.display(pz_tree)
with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
tree = evm.visualization.convert_tree_to_treescope(composition)
treescope.display(tree)
```

You can also save the tree to an HTML file.
```{code-cell} python
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
pz_tree = evm.visualization.convert_tree_to_penzai(composition)
contents = pz.ts.render_to_html(pz_tree, roundtrip_mode=True)
with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
tree = evm.visualization.convert_tree_to_treescope(composition)
contents = treescope.render_to_html(tree)
with open("composition.html", "w") as f:
f.write(contents)
Expand Down Expand Up @@ -121,9 +121,9 @@ rng_keys = jax.random.split(rng_key, 100)
vec_sample = jax.vmap(evm.parameter.sample, in_axes=(None, 0))
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
pz_tree = evm.visualization.convert_tree_to_penzai(vec_sample(params, rng_keys))
pz.ts.display(pz_tree)
with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
tree = evm.visualization.convert_tree_to_treescope(vec_sample(params, rng_keys))
treescope.display(tree)
```

Many minimizers from the JAX ecosystem are e.g. batchable (`optax`, `optimistix`), which allows you vectorize _full fits_, e.g., for embarrassingly parallel likleihood profiles.
Expand Down
3 changes: 2 additions & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ sphinx-togglebutton = "*"
myst-nb = "*"

[pypi-dependencies]
penzai = "*"
treescope = "*"
penzai = ">=0.2.0"

[tasks]
postinstall = "pip install -e '.[dev]' && pip install pre-commit && pre-commit install"
Expand Down
6 changes: 3 additions & 3 deletions src/evermore/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from evermore.pdf import Normal, Poisson

__all__ = [
"convert_tree_to_penzai",
"convert_tree_to_treescope",
]


Expand Down Expand Up @@ -68,7 +68,7 @@ class EvermoreClassesContext(threading.local):
)


def convert_tree_to_penzai(tree: PyTree) -> PyTree:
def convert_tree_to_treescope(tree: PyTree) -> PyTree:
from functools import partial

for cls in Context.cls_types:
Expand Down Expand Up @@ -104,7 +104,7 @@ def _pretty(x: Any) -> Any:
attrs = {k: _pretty(getattr(leaf, k)) for k in attributes["__annotations__"]}

new_cls = pz.pytree_dataclass(
type(leaf_cls.__name__, (pz.Layer,), dict(attributes))
type(leaf_cls.__name__, (pz.nn.Layer,), dict(attributes))
)
return new_cls(**attrs)
return leaf

0 comments on commit bc89312

Please sign in to comment.