From 732354e9dae50c09521e64339a1d9efded4fe633 Mon Sep 17 00:00:00 2001 From: felixzinn <151917409+felixzinn@users.noreply.github.com> Date: Wed, 14 Aug 2024 11:53:40 +0200 Subject: [PATCH] move to penzai V2 API and independent treescope package (#17) * move to penzai V2 API and independent treescope package * [visualization] add custom treescope repr * SupportTreescope -> SupportsTreescope --- docs/building_blocks.md | 17 ++- docs/tips_and_tricks.md | 25 +++-- pyproject.toml | 3 +- src/evermore/effect.py | 3 +- src/evermore/modifier.py | 3 +- src/evermore/parameter.py | 3 +- src/evermore/pdf.py | 4 +- src/evermore/staterror.py | 3 +- src/evermore/visualization.py | 191 ++++++++++++++++++---------------- 9 files changed, 134 insertions(+), 118 deletions(-) diff --git a/docs/building_blocks.md b/docs/building_blocks.md index f98c7d0..78ab902 100644 --- a/docs/building_blocks.md +++ b/docs/building_blocks.md @@ -122,10 +122,10 @@ Correlate a Parameter ``` -:::{admonition} Inspect `evm.Parameters` with `penzai` +:::{admonition} Inspect `evm.Parameters` with `treescope` :class: tip dropdown -Inspect a (PyTree of) `evm.Parameters` with [`penzai`'s treescope](https://penzai.readthedocs.io/en/stable/notebooks/treescope_prettyprinting.html) visualization in IPython or Colab notebooks (see for more information). +Inspect a (PyTree of) `evm.Parameters` with [treescope](https://treescope.readthedocs.io/en/stable/index.html) visualization in IPython or Colab notebooks (see for more information). You can even add custom visualizers, such as: ```{code-block} python @@ -134,9 +134,8 @@ import evermore as evm tree = {"a": evm.NormalParameter(), "b": evm.NormalParameter()} -with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()): - pz_tree = evm.visualization.convert_tree_to_penzai(tree) - pz.ts.display(pz_tree) +with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()): + treescope.display(tree) ``` ::: @@ -275,7 +274,7 @@ Multiple modifiers should be combined using `evm.modifier.Compose` or the `@` op import jax import jax.numpy as jnp import evermore as evm -from penzai import pz +import treescope jax.config.update("jax_enable_x64", True) @@ -293,7 +292,7 @@ modifier2 = param.scale_log(up=1.1, down=0.9) (modifier1 @ modifier2)(jnp.array([10, 20, 30])) # -> Array([10.259877, 20.500944, 30.760822], dtype=float32) -with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()): - pz_tree = evm.visualization.convert_tree_to_penzai(modifier1 @ modifier2) - pz.ts.display(pz_tree) +with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()): + composition = modifier1 @ modifier2 + treescope.display(composition) ``` diff --git a/docs/tips_and_tricks.md b/docs/tips_and_tricks.md index b6577bb..eaee740 100644 --- a/docs/tips_and_tricks.md +++ b/docs/tips_and_tricks.md @@ -15,17 +15,17 @@ kernelspec: Here are some advanced tips and tricks. -(penzai-visualization)= -## penzai Visualization +(treescope-visualization)= +## treescope Visualization -evermore components can be visualized with `penzai`. Convert the corresponding PyTree using `evermore.visualization.convert_tree_to_penzai` and use it with `penzai`. In IPython notebooks you can display the tree using `penzai.ts.display`. +evermore components can be visualized with [treescope](https://treescope.readthedocs.io/en/stable/index.html). In IPython notebooks you can display the tree using `treescope.display`. ```{code-cell} ipython3 import jax import jax.numpy as jnp import evermore as evm import equinox as eqx -from penzai import pz +import treescope jax.config.update("jax_enable_x64", True) @@ -50,16 +50,14 @@ composition = evm.modifier.Compose( evm.Modifier(parameter=sigma1, effect=evm.effect.AsymmetricExponential(up=1.2, down=0.8)), ) -with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()): - pz_tree = evm.visualization.convert_tree_to_penzai(composition) - pz.ts.display(pz_tree) +with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()): + treescope.display(composition) ``` You can also save the tree to an HTML file. ```{code-cell} python -with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()): - pz_tree = evm.visualization.convert_tree_to_penzai(composition) - contents = pz.ts.render_to_html(pz_tree, roundtrip_mode=True) +with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()): + contents = treescope.render_to_html(composition) with open("composition.html", "w") as f: f.write(contents) @@ -112,6 +110,7 @@ You can e.g. sample the parameter values multiple times vectorized from its prio ```{code-cell} ipython3 import jax import evermore as evm +import treescope params = {"a": evm.NormalParameter(), "b": evm.NormalParameter()} @@ -121,9 +120,9 @@ rng_keys = jax.random.split(rng_key, 100) vec_sample = jax.vmap(evm.parameter.sample, in_axes=(None, 0)) -with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()): - pz_tree = evm.visualization.convert_tree_to_penzai(vec_sample(params, rng_keys)) - pz.ts.display(pz_tree) +with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()): + tree = vec_sample(params, rng_keys) + treescope.display(tree) ``` Many minimizers from the JAX ecosystem are e.g. batchable (`optax`, `optimistix`), which allows you vectorize _full fits_, e.g., for embarrassingly parallel likleihood profiles. diff --git a/pyproject.toml b/pyproject.toml index 445b821..5ffccbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "jaxlib", "jaxtyping", "equinox>=0.10.6", # eqx.field + "treescope", ] [project.optional-dependencies] @@ -45,7 +46,6 @@ docs = [ "sphinx-book-theme", "sphinx-design", "sphinx-togglebutton", - "penzai", ] [project.urls] @@ -135,6 +135,7 @@ jax = "*" jaxlib = "*" jaxtyping = "*" equinox = ">=0.10.6" # eqx.field +treescope = "*" [tool.pixi.pypi-dependencies] evermore = { path = ".", editable = true } diff --git a/src/evermore/effect.py b/src/evermore/effect.py index e226cf5..7c37597 100644 --- a/src/evermore/effect.py +++ b/src/evermore/effect.py @@ -9,6 +9,7 @@ from evermore.custom_types import OffsetAndScale from evermore.parameter import Parameter +from evermore.visualization import SupportsTreescope __all__ = [ "Effect", @@ -23,7 +24,7 @@ def __dir__(): return __all__ -class Effect(eqx.Module): +class Effect(eqx.Module, SupportsTreescope): @abc.abstractmethod def __call__(self, parameter: PyTree[Parameter], hist: Array) -> OffsetAndScale: ... diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index 8411a4c..9da092b 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -15,6 +15,7 @@ from evermore.effect import DEFAULT_EFFECT from evermore.parameter import Parameter from evermore.util import tree_stack +from evermore.visualization import SupportsTreescope if TYPE_CHECKING: from evermore.effect import Effect @@ -58,7 +59,7 @@ def __matmul__(self: ModifierLike, other: ModifierLike) -> Compose: return Compose(self, other) -class ModifierBase(ApplyFn, MatMulCompose, AbstractModifier): +class ModifierBase(ApplyFn, MatMulCompose, AbstractModifier, SupportsTreescope): """ This serves as a base class for all modifiers. It automatically implements the __call__ method to apply the scale factors to the hist array diff --git a/src/evermore/parameter.py b/src/evermore/parameter.py index 7b508ba..79db732 100644 --- a/src/evermore/parameter.py +++ b/src/evermore/parameter.py @@ -12,6 +12,7 @@ from evermore.custom_types import PDFLike from evermore.pdf import Normal, Poisson from evermore.util import filter_tree_map +from evermore.visualization import SupportsTreescope if TYPE_CHECKING: from evermore.modifier import Modifier @@ -33,7 +34,7 @@ def __dir__(): return __all__ -class Parameter(eqx.Module): +class Parameter(eqx.Module, SupportsTreescope): """ Implementation of a general Parameter class. The class is used to define the parameters of a statistical model. Key is the value attribute, which holds the actual value of the parameter. In additon, diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index b6f2cbc..3067b7e 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -8,6 +8,8 @@ from jax.scipy.special import gammaln, xlogy from jaxtyping import Array, PRNGKeyArray +from evermore.visualization import SupportsTreescope + __all__ = [ "PDF", "Normal", @@ -19,7 +21,7 @@ def __dir__(): return __all__ -class PDF(eqx.Module): +class PDF(eqx.Module, SupportsTreescope): @abstractmethod def log_prob(self, x: Array) -> Array: ... diff --git a/src/evermore/staterror.py b/src/evermore/staterror.py index 30156d5..8c79186 100644 --- a/src/evermore/staterror.py +++ b/src/evermore/staterror.py @@ -14,6 +14,7 @@ from evermore.parameter import NormalParameter, Parameter from evermore.pdf import Poisson from evermore.util import sum_over_leaves +from evermore.visualization import SupportsTreescope __all__ = [ "StatErrors", @@ -24,7 +25,7 @@ def __dir__(): return __all__ -class StatErrors(eqx.Module): +class StatErrors(eqx.Module, SupportsTreescope): """ Create staterror (barlow-beeston) parameters. diff --git a/src/evermore/visualization.py b/src/evermore/visualization.py index ece7911..f878b75 100644 --- a/src/evermore/visualization.py +++ b/src/evermore/visualization.py @@ -1,110 +1,121 @@ from __future__ import annotations import dataclasses -import threading +from collections.abc import Callable from typing import Any -import jax.tree_util as jtu -from jaxtyping import Array, PyTree - -from evermore.custom_types import ModifierLike, PDFLike -from evermore.effect import ( - AsymmetricExponential, - Effect, - Identity, - Linear, - VerticalTemplateMorphing, -) -from evermore.modifier import ( - BooleanMask, - Compose, - Modifier, - Transform, - TransformOffset, - TransformScale, - Where, -) -from evermore.parameter import NormalParameter, Parameter -from evermore.pdf import Normal, Poisson - -__all__ = [ - "convert_tree_to_penzai", -] - - -def __dir__(): - return __all__ - - -@dataclasses.dataclass -class EvermoreClassesContext(threading.local): - cls_types: list[Any] = dataclasses.field(default_factory=list) - - -Context = EvermoreClassesContext() - - -Context.cls_types.extend( - [ - NormalParameter, - Parameter, - Identity, - Linear, - AsymmetricExponential, - VerticalTemplateMorphing, - Effect, - Modifier, - Compose, - Where, - BooleanMask, - Transform, - TransformScale, - TransformOffset, - Normal, - Poisson, - ModifierLike, - PDFLike, - ] +from treescope import ( + dataclass_util, + formatting_util, + renderers, + rendering_parts, ) -def convert_tree_to_penzai(tree: PyTree) -> PyTree: - from functools import partial +class SupportsTreescope: + def __treescope_repr__( + self, + path: str, + subtree_renderer: Callable[[Any, str | None], rendering_parts.Rendering], + ) -> rendering_parts.Rendering: + return handle_evermore_classes(self, path, subtree_renderer) + + +def handle_evermore_classes( + node: Any, + path: str | None, + subtree_renderer: renderers.TreescopeSubtreeRenderer, +) -> rendering_parts.RenderableTreePart | rendering_parts.Rendering: + """Renders evermore classes. + Taken from: https://github.com/google-deepmind/penzai/blob/b1bd577dc34f0e7b8f7fef3bbeb2cd571c2f8fcd/penzai/core/_treescope_handlers/struct_handler.py + + Args: + node: The node to render. + path: The path to the node. (Optional) + subtree_renderer: A recursive renderer for subtrees. + + Returns: + A rendering of evermore classes. + """ + + # get prefix, e.g. "Parameter(" + prefix = render_evermore_constructor(node) + + # get fields of the dataclass, e.g. value=1.0 + fields = dataclasses.fields(node) + + # get children of the tree + children = rendering_parts.build_field_children( + node, + path, + subtree_renderer, + fields_or_attribute_names=fields, + attr_style_fn=evermore_attr_style_fn_for_fields(fields), + ) + + # get colors for the background of the tree node + def _treescope_color(node) -> str: + """Returns the color of the tree node.""" + + type_string = type(node).__module__ + "." + type(node).__qualname__ + return formatting_util.color_from_string(type_string) + + background_color, background_pattern = ( + formatting_util.parse_simple_color_and_pattern_spec( + _treescope_color(node), type(node).__name__ + ) + ) - for cls in Context.cls_types: + return rendering_parts.build_foldable_tree_node_from_children( + prefix=prefix, + children=children, + suffix=")", + background_color=background_color, + background_pattern=background_pattern, + ) - def _is_evm_cls(leaf: Any, cls: Any) -> bool: - return isinstance(leaf, cls) - tree = jtu.tree_map( - partial(_convert, cls=cls), tree, is_leaf=partial(_is_evm_cls, cls=cls) - ) - return tree +def evermore_attr_style_fn_for_fields( + fields, +) -> Callable[[str], rendering_parts.RenderableTreePart]: + """Builds a function to render attributes of an evermore class. + The resulting function will render pytree node fields in a different style. + E.g. the field "value" of a Parameter class will be rendered in a different style. -def _convert(leaf: Any, cls: Any) -> Any: - from penzai.deprecated.v1 import pz + Taken from: https://github.com/google-deepmind/penzai/blob/b1bd577dc34f0e7b8f7fef3bbeb2cd571c2f8fcd/penzai/core/_treescope_handlers/struct_handler.py - if isinstance(leaf, cls) and dataclasses.is_dataclass(leaf): - fields = dataclasses.fields(leaf) + Args: + fields: The fields of the evermore class. - leaf_cls = type(leaf) - attributes: dict[str, Any] = { - "__annotations__": {field.name: field.type for field in fields} - } + Returns: + A function that takes a field name and returns a RenderableTreePart.""" + fields_by_name = {field.name: field for field in fields} - if callable(leaf_cls): - attributes["__call__"] = leaf_cls.__call__ + def attr_style_fn(field_name): + field = fields_by_name[field_name] + if field.metadata.get("pytree_node", True): + return rendering_parts.custom_style( + rendering_parts.text(field_name), + css_style="font-style: italic; color: #00255f;", + ) + return rendering_parts.text(field_name) - def _pretty(x: Any) -> Any: - if isinstance(x, Array) and x.size == 1: - return x.item() - return x + return attr_style_fn - attrs = {k: _pretty(getattr(leaf, k)) for k in attributes["__annotations__"]} - new_cls = pz.pytree_dataclass( - type(leaf_cls.__name__, (pz.Layer,), dict(attributes)) +def render_evermore_constructor(node: Any) -> rendering_parts.RenderableTreePart: + """Renders the constructor of an evermore class, with an open parenthesis. + Taken from: https://github.com/google-deepmind/penzai/blob/b1bd577dc34f0e7b8f7fef3bbeb2cd571c2f8fcd/penzai/core/_treescope_handlers/struct_handler.py + """ + if dataclass_util.init_takes_fields(type(node)): + return rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(node)), "(" ) - return new_cls(**attrs) - return leaf + + return rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(node)), + rendering_parts.roundtrip_condition( + roundtrip=rendering_parts.text(".from_attributes") + ), + )