Skip to content

Commit

Permalink
move to penzai V2 API and independent treescope package (#17)
Browse files Browse the repository at this point in the history
* move to penzai V2 API and independent treescope package

* [visualization] add custom treescope repr

* SupportTreescope -> SupportsTreescope
  • Loading branch information
felixzinn authored Aug 14, 2024
1 parent 14f9cb2 commit 732354e
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 118 deletions.
17 changes: 8 additions & 9 deletions docs/building_blocks.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ Correlate a Parameter
```


:::{admonition} Inspect `evm.Parameters` with `penzai`
:::{admonition} Inspect `evm.Parameters` with `treescope`
:class: tip dropdown

Inspect a (PyTree of) `evm.Parameters` with [`penzai`'s treescope](https://penzai.readthedocs.io/en/stable/notebooks/treescope_prettyprinting.html) visualization in IPython or Colab notebooks (see <project:#penzai-visualization> for more information).
Inspect a (PyTree of) `evm.Parameters` with [treescope](https://treescope.readthedocs.io/en/stable/index.html) visualization in IPython or Colab notebooks (see <project:#treescope-visualization> for more information).
You can even add custom visualizers, such as:

```{code-block} python
Expand All @@ -134,9 +134,8 @@ import evermore as evm
tree = {"a": evm.NormalParameter(), "b": evm.NormalParameter()}
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
pz_tree = evm.visualization.convert_tree_to_penzai(tree)
pz.ts.display(pz_tree)
with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
treescope.display(tree)
```
:::

Expand Down Expand Up @@ -275,7 +274,7 @@ Multiple modifiers should be combined using `evm.modifier.Compose` or the `@` op
import jax
import jax.numpy as jnp
import evermore as evm
from penzai import pz
import treescope
jax.config.update("jax_enable_x64", True)
Expand All @@ -293,7 +292,7 @@ modifier2 = param.scale_log(up=1.1, down=0.9)
(modifier1 @ modifier2)(jnp.array([10, 20, 30]))
# -> Array([10.259877, 20.500944, 30.760822], dtype=float32)
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
pz_tree = evm.visualization.convert_tree_to_penzai(modifier1 @ modifier2)
pz.ts.display(pz_tree)
with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
composition = modifier1 @ modifier2
treescope.display(composition)
```
25 changes: 12 additions & 13 deletions docs/tips_and_tricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@ kernelspec:

Here are some advanced tips and tricks.

(penzai-visualization)=
## penzai Visualization
(treescope-visualization)=
## treescope Visualization

evermore components can be visualized with `penzai`. Convert the corresponding PyTree using `evermore.visualization.convert_tree_to_penzai` and use it with `penzai`. In IPython notebooks you can display the tree using `penzai.ts.display`.
evermore components can be visualized with [treescope](https://treescope.readthedocs.io/en/stable/index.html). In IPython notebooks you can display the tree using `treescope.display`.

```{code-cell} ipython3
import jax
import jax.numpy as jnp
import evermore as evm
import equinox as eqx
from penzai import pz
import treescope
jax.config.update("jax_enable_x64", True)
Expand All @@ -50,16 +50,14 @@ composition = evm.modifier.Compose(
evm.Modifier(parameter=sigma1, effect=evm.effect.AsymmetricExponential(up=1.2, down=0.8)),
)
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
pz_tree = evm.visualization.convert_tree_to_penzai(composition)
pz.ts.display(pz_tree)
with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
treescope.display(composition)
```

You can also save the tree to an HTML file.
```{code-cell} python
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
pz_tree = evm.visualization.convert_tree_to_penzai(composition)
contents = pz.ts.render_to_html(pz_tree, roundtrip_mode=True)
with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
contents = treescope.render_to_html(composition)
with open("composition.html", "w") as f:
f.write(contents)
Expand Down Expand Up @@ -112,6 +110,7 @@ You can e.g. sample the parameter values multiple times vectorized from its prio
```{code-cell} ipython3
import jax
import evermore as evm
import treescope
params = {"a": evm.NormalParameter(), "b": evm.NormalParameter()}
Expand All @@ -121,9 +120,9 @@ rng_keys = jax.random.split(rng_key, 100)
vec_sample = jax.vmap(evm.parameter.sample, in_axes=(None, 0))
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
pz_tree = evm.visualization.convert_tree_to_penzai(vec_sample(params, rng_keys))
pz.ts.display(pz_tree)
with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
tree = vec_sample(params, rng_keys)
treescope.display(tree)
```

Many minimizers from the JAX ecosystem are e.g. batchable (`optax`, `optimistix`), which allows you vectorize _full fits_, e.g., for embarrassingly parallel likleihood profiles.
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"jaxlib",
"jaxtyping",
"equinox>=0.10.6", # eqx.field
"treescope",
]

[project.optional-dependencies]
Expand All @@ -45,7 +46,6 @@ docs = [
"sphinx-book-theme",
"sphinx-design",
"sphinx-togglebutton",
"penzai",
]

[project.urls]
Expand Down Expand Up @@ -135,6 +135,7 @@ jax = "*"
jaxlib = "*"
jaxtyping = "*"
equinox = ">=0.10.6" # eqx.field
treescope = "*"

[tool.pixi.pypi-dependencies]
evermore = { path = ".", editable = true }
Expand Down
3 changes: 2 additions & 1 deletion src/evermore/effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from evermore.custom_types import OffsetAndScale
from evermore.parameter import Parameter
from evermore.visualization import SupportsTreescope

__all__ = [
"Effect",
Expand All @@ -23,7 +24,7 @@ def __dir__():
return __all__


class Effect(eqx.Module):
class Effect(eqx.Module, SupportsTreescope):
@abc.abstractmethod
def __call__(self, parameter: PyTree[Parameter], hist: Array) -> OffsetAndScale: ...

Expand Down
3 changes: 2 additions & 1 deletion src/evermore/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from evermore.effect import DEFAULT_EFFECT
from evermore.parameter import Parameter
from evermore.util import tree_stack
from evermore.visualization import SupportsTreescope

if TYPE_CHECKING:
from evermore.effect import Effect
Expand Down Expand Up @@ -58,7 +59,7 @@ def __matmul__(self: ModifierLike, other: ModifierLike) -> Compose:
return Compose(self, other)


class ModifierBase(ApplyFn, MatMulCompose, AbstractModifier):
class ModifierBase(ApplyFn, MatMulCompose, AbstractModifier, SupportsTreescope):
"""
This serves as a base class for all modifiers.
It automatically implements the __call__ method to apply the scale factors to the hist array
Expand Down
3 changes: 2 additions & 1 deletion src/evermore/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from evermore.custom_types import PDFLike
from evermore.pdf import Normal, Poisson
from evermore.util import filter_tree_map
from evermore.visualization import SupportsTreescope

if TYPE_CHECKING:
from evermore.modifier import Modifier
Expand All @@ -33,7 +34,7 @@ def __dir__():
return __all__


class Parameter(eqx.Module):
class Parameter(eqx.Module, SupportsTreescope):
"""
Implementation of a general Parameter class. The class is used to define the parameters of a statistical model.
Key is the value attribute, which holds the actual value of the parameter. In additon,
Expand Down
4 changes: 3 additions & 1 deletion src/evermore/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from jax.scipy.special import gammaln, xlogy
from jaxtyping import Array, PRNGKeyArray

from evermore.visualization import SupportsTreescope

__all__ = [
"PDF",
"Normal",
Expand All @@ -19,7 +21,7 @@ def __dir__():
return __all__


class PDF(eqx.Module):
class PDF(eqx.Module, SupportsTreescope):
@abstractmethod
def log_prob(self, x: Array) -> Array: ...

Expand Down
3 changes: 2 additions & 1 deletion src/evermore/staterror.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from evermore.parameter import NormalParameter, Parameter
from evermore.pdf import Poisson
from evermore.util import sum_over_leaves
from evermore.visualization import SupportsTreescope

__all__ = [
"StatErrors",
Expand All @@ -24,7 +25,7 @@ def __dir__():
return __all__


class StatErrors(eqx.Module):
class StatErrors(eqx.Module, SupportsTreescope):
"""
Create staterror (barlow-beeston) parameters.
Expand Down
Loading

0 comments on commit 732354e

Please sign in to comment.