From 5aff06660664c6429ad9015feae97312b1800373 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 6 Mar 2024 13:11:59 +0100 Subject: [PATCH 01/22] start of restructure: be more close to equinox' philosophy --- docs/conf.py | 2 + examples/dnn_weights_constraint.py | 36 ++ examples/grad_nll.py | 31 -- examples/model.py | 131 +++---- examples/nll_fit.py | 18 - examples/nll_profiling.py | 67 ---- examples/toy_generation.py | 52 +-- pyproject.toml | 3 +- src/evermore/__init__.py | 24 +- src/evermore/custom_types.py | 5 +- src/evermore/effect.py | 44 ++- src/evermore/ipy_util.py | 50 --- src/evermore/likelihood.py | 128 ------- src/evermore/loss.py | 74 ++++ src/evermore/model.py | 205 ----------- src/evermore/modifier.py | 545 ++++++++++++++--------------- src/evermore/optimizer.py | 99 ------ src/evermore/parameter.py | 58 ++- src/evermore/pdf.py | 80 ++++- src/evermore/sample.py | 36 ++ src/evermore/util.py | 268 ++------------ tests/test_optimizer.py | 31 -- tests/test_parameter.py | 27 +- tests/test_util.py | 41 +-- 24 files changed, 690 insertions(+), 1365 deletions(-) create mode 100644 examples/dnn_weights_constraint.py delete mode 100644 src/evermore/ipy_util.py delete mode 100644 src/evermore/likelihood.py create mode 100644 src/evermore/loss.py delete mode 100644 src/evermore/model.py delete mode 100644 src/evermore/optimizer.py create mode 100644 src/evermore/sample.py delete mode 100644 tests/test_optimizer.py diff --git a/docs/conf.py b/docs/conf.py index 7558ac7..0dcab2f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -38,6 +38,8 @@ "use_edit_page_button": True, } html_context = {"default_mode": "light"} +html_logo = "../assets/favicon.png" +html_favicon = "../assets/favicon.png" extensions = [ "sphinx.ext.autodoc", diff --git a/examples/dnn_weights_constraint.py b/examples/dnn_weights_constraint.py new file mode 100644 index 0000000..f18f145 --- /dev/null +++ b/examples/dnn_weights_constraint.py @@ -0,0 +1,36 @@ +import equinox as eqx +import jax +import jax.numpy as jnp + +import evermore as evm + + +class LinearConstrained(eqx.Module): + weights: evm.Parameter + biases: jax.Array + + def __init__(self, in_size, out_size, key): + self.biases = jax.random.normal(key, (out_size,)) + self.weights = evm.Parameter(value=jax.random.normal(key, (out_size, in_size))) + self.weights.constraints.add(evm.pdf.Gauss(mean=0.0, width=0.5)) + + def __call__(self, x: jax.Array): + return self.weights.value @ x + self.biases + + +@eqx.filter_jit +def loss_fn(model, x, y): + pred_y = jax.vmap(model)(x) + mse = jax.numpy.mean((y - pred_y) ** 2) + constraints = evm.loss.get_param_constraints(model) + # sum them all up for each weight + constraints = jax.tree_util.tree_map(jnp.sum, constraints) + return mse + evm.util.sum_leaves(constraints) + + +batch_size, in_size, out_size = 32, 2, 3 +model = LinearConstrained(in_size, out_size, key=jax.random.PRNGKey(0)) +x = jax.numpy.zeros((batch_size, in_size)) +y = jax.numpy.zeros((batch_size, out_size)) +loss_val = loss_fn(model, x, y) +grads = eqx.filter_grad(loss_fn)(model, x, y) diff --git a/examples/grad_nll.py b/examples/grad_nll.py index c5a3662..e69de29 100644 --- a/examples/grad_nll.py +++ b/examples/grad_nll.py @@ -1,31 +0,0 @@ -from __future__ import annotations - -import equinox as eqx -from jax import config -from model import init_values, model, observation, optimizer - -import evermore as evm - -config.update("jax_enable_x64", True) - -# create negative log likelihood -nll = evm.likelihood.NLL(model=model, observation=observation) - -# fit -params, state = optimizer.fit(fun=nll, init_values=init_values) - -# gradients of nll of fitted model -fast_grad_nll = eqx.filter_jit(eqx.filter_grad(nll)) -grads = fast_grad_nll(params) -# gradients of nll of fitted model only wrt to `mu` -# basically: pass the parameters dict of which you want the gradients -params_ = {k: v for k, v in params.items() if k == "mu"} -grad_mu = fast_grad_nll(params_) - -# hessian + cov_matrix of nll of fitted model -hessian = eqx.filter_jit(evm.likelihood.Hessian(model=model, observation=observation))() - -# covariance matrix of fitted model -covmatrix = eqx.filter_jit( - evm.likelihood.CovMatrix(model=model, observation=observation) -)() diff --git a/examples/model.py b/examples/model.py index f7756b7..be416a0 100644 --- a/examples/model.py +++ b/examples/model.py @@ -1,94 +1,79 @@ from __future__ import annotations +from typing import Any + +import equinox as eqx +import jax import jax.numpy as jnp import evermore as evm -class SPlusBModel(evm.Model): - def __call__(self, processes: dict, parameters: dict) -> evm.Result: - res = evm.Result() +class SPlusBModel(eqx.Module): + mu: evm.Parameter + norm1: evm.Parameter + norm2: evm.Parameter + shape1: evm.Parameter - mu_modifier = evm.modifier( - name="mu", parameter=parameters["mu"], effect=evm.effect.unconstrained() - ) - res.add( - process="signal", - expectation=mu_modifier(processes[("signal", "nominal")]), - ) + def __init__(self) -> None: + self.mu = evm.Parameter(value=jnp.array([1.0])) + self.norm1 = evm.Parameter() + self.norm2 = evm.Parameter() + self.shape1 = evm.Parameter() - bkg1_modifier = evm.compose( - evm.modifier( - name="lnN1", - parameter=parameters["norm1"], - effect=evm.effect.lnN((0.9, 1.1)), - ), - evm.modifier( - name="shape1_bkg1", - parameter=parameters["shape1"], - effect=evm.effect.shape( - up=processes[("background1", "shape_up")], - down=processes[("background1", "shape_down")], - ), - ), - ) - res.add( - process="background1", - expectation=bkg1_modifier(processes[("background1", "nominal")]), - ) + def __call__(self, hists: dict[Any, jax.Array]) -> dict[str, jax.Array]: + expectations = {} - bkg2_modifier = evm.compose( - evm.modifier( - name="lnN2", - parameter=parameters["norm2"], - effect=evm.effect.lnN((0.95, 1.05)), - ), - evm.modifier( - name="shape1_bkg2", - parameter=parameters["shape1"], - effect=evm.effect.shape( - up=processes[("background2", "shape_up")], - down=processes[("background2", "shape_down")], - ), - ), + # signal process + sig_mod = self.mu.unconstrained() + expectations["signal"] = sig_mod(hists[("signal", "nominal")]) + + # bkg1 process + bkg1_mod = self.norm1.lnN(width=jnp.array([0.9, 1.1])) @ self.shape1.shape( + up=hists[("bkg1", "shape_up")], + down=hists[("bkg1", "shape_down")], ) - res.add( - process="background2", - expectation=bkg2_modifier(processes[("background2", "nominal")]), + expectations["bkg1"] = bkg1_mod(hists[("bkg1", "nominal")]) + + # bkg2 process + bkg2_mod = self.norm2.lnN(width=jnp.array([0.95, 1.05])) @ self.shape1.shape( + up=hists[("bkg2", "shape_up")], + down=hists[("bkg2", "shape_down")], ) - return res + expectations["bkg2"] = bkg2_mod(hists[("bkg2", "nominal")]) + # return the modified expectations + return expectations -def create_model(): - processes = { - ("signal", "nominal"): jnp.array([3]), - ("background1", "nominal"): jnp.array([10]), - ("background2", "nominal"): jnp.array([20]), - ("background1", "shape_up"): jnp.array([12]), - ("background1", "shape_down"): jnp.array([8]), - ("background2", "shape_up"): jnp.array([23]), - ("background2", "shape_down"): jnp.array([19]), - } - parameters = { - "mu": evm.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)), - "norm1": evm.Parameter(value=jnp.array([0.0])), - "norm2": evm.Parameter(value=jnp.array([0.0])), - "shape1": evm.Parameter(value=jnp.array([0.0])), - } - # return model - return SPlusBModel(processes=processes, parameters=parameters) +model = SPlusBModel() -model = create_model() +hists = { + ("signal", "nominal"): jnp.array([3]), + ("bkg1", "nominal"): jnp.array([10]), + ("bkg2", "nominal"): jnp.array([20]), + ("bkg1", "shape_up"): jnp.array([12]), + ("bkg1", "shape_down"): jnp.array([8]), + ("bkg2", "shape_up"): jnp.array([23]), + ("bkg2", "shape_down"): jnp.array([19]), +} -init_values = model.parameter_values observation = jnp.array([37]) -asimov = model.evaluate().expectation() + +nll = evm.loss.PoissonNLL() + + +@eqx.filter_jit +def loss(model, hists, observation): + expectations = model(hists) + constraints = evm.loss.get_param_constraints(model) + return nll( + expectation=evm.util.sum_leaves(expectations), + observation=observation, + constraint=evm.util.sum_leaves(constraints), + ) -# create optimizer (from `jaxopt`) -optimizer = evm.optimizer.JaxOptimizer.make( - name="LBFGS", - settings={"maxiter": 5, "jit": True, "unroll": True}, -) +loss_val = loss(model, hists, observation) +grads = eqx.filter_grad(loss)(model, hists, observation) diff --git a/examples/nll_fit.py b/examples/nll_fit.py index c983205..e69de29 100644 --- a/examples/nll_fit.py +++ b/examples/nll_fit.py @@ -1,18 +0,0 @@ -from __future__ import annotations - -from jax import config -from model import init_values, model, observation, optimizer - -from evermore.likelihood import NLL - -config.update("jax_enable_x64", True) - - -# create negative log likelihood -nll = NLL(model=model, observation=observation) - -# fit -values, state = optimizer.fit(fun=nll, init_values=init_values) - -# update model with fitted values -fitted_model = model.update(values=values) diff --git a/examples/nll_profiling.py b/examples/nll_profiling.py index 45bf296..e69de29 100644 --- a/examples/nll_profiling.py +++ b/examples/nll_profiling.py @@ -1,67 +0,0 @@ -from __future__ import annotations - -from functools import partial - -import equinox as eqx -import jax -import jax.numpy as jnp -from jax import config -from model import asimov, model, optimizer - -from evermore import Model -from evermore.likelihood import NLL -from evermore.optimizer import JaxOptimizer - -config.update("jax_enable_x64", True) - - -def nll_profiling( - value_name: str, - scan_points: jax.Array, - model: Model, - observation: jax.Array, - optimizer: JaxOptimizer, - fit: bool, -) -> jax.Array: - # define single fit for a fixed parameter of interest (poi) - @partial(jax.jit, static_argnames=("value_name", "optimizer", "fit")) - def fixed_poi_fit( - value_name: str, - scan_point: jax.Array, - model: Model, - observation: jax.Array, - optimizer: JaxOptimizer, - fit: bool, - ) -> jax.Array: - # fix theta into the model - model = model.update(values={value_name: scan_point}) - init_values = model.parameter_values - init_values.pop(value_name, 1) - # minimize - nll = eqx.filter_jit(NLL(model=model, observation=observation)) - if fit: - values, _ = optimizer.fit(fun=nll, init_values=init_values) - else: - values = model.parameter_values - return nll(values=values) - - # vectorise for multiple fixed values (scan points) - fixed_poi_fit_vec = jax.vmap( - fixed_poi_fit, in_axes=(None, 0, None, None, None, None) - ) - return fixed_poi_fit_vec( - value_name, scan_points, model, observation, optimizer, fit - ) - - -# profile the NLL around starting point of `0` -scan_points = jnp.r_[-1.9:2.0:0.1] - -profile_postfit = nll_profiling( - value_name="norm1", - scan_points=scan_points, - model=model, - observation=asimov, - optimizer=optimizer, - fit=True, -) diff --git a/examples/toy_generation.py b/examples/toy_generation.py index 06045ef..e7038d9 100644 --- a/examples/toy_generation.py +++ b/examples/toy_generation.py @@ -1,31 +1,43 @@ -from __future__ import annotations +from typing import Any import equinox as eqx import jax -from jax import config -from model import init_values, model, observation, optimizer +from jaxtyping import Array, PRNGKeyArray +from model import hists, model, observation -from evermore.likelihood import NLL, SampleToy +import evermore as evm -config.update("jax_enable_x64", True) +key = jax.random.PRNGKey(0) +# generate a new model with sampled parameters according to their constraint pdfs +toymodel = evm.sample.toy_module(model, key) -# create negative log likelihood -nll = NLL(model=model, observation=observation) -# fit -values, state = optimizer.fit(fun=nll, init_values=init_values) +# generate new expectation based on the toy model +def toy_expectation( + key: PRNGKeyArray, + module: eqx.Module, + hists: dict[Any, Array], +) -> Array: + toymodel = evm.sample.toy_module(model, key) + expectations = toymodel(hists) + return evm.util.sum_leaves(expectations) -# create sampling method -sample_toy = SampleToy(model=model, observation=observation) -# vectorise and jit -sample_toys = eqx.filter_vmap(in_axes=(None, 0))(eqx.filter_jit(sample_toy)) -sample_toy(values, jax.random.PRNGKey(1234)) +expectation = toy_expectation(key, model, hists) -# sample 10 toys based on fitted parameters -keys = jax.random.split(jax.random.PRNGKey(1234), num=10) -# postfit toys -toys_postfit = sample_toys(values, keys) -# prefit toys -toys_prefit = sample_toys(init_values, keys) + +# generate a new expectations vectorized over many keys +keys = jax.random.split(key, 1000) + +# vectorized toy expectation +toy_expectation_vec = jax.vmap(toy_expectation, in_axes=(0, None, None)) +expectations = toy_expectation_vec(keys, model, hists) + + +# just sample observations with poisson +poisson_obs = evm.pdf.Poisson(observation) +sampled_observation = poisson_obs.sample(key) + +# vectorized sampling +sampled_observations = jax.vmap(poisson_obs.sample)(keys) diff --git a/pyproject.toml b/pyproject.toml index e4ac42b..34223e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,12 +27,11 @@ classifiers = [ dynamic = ["version"] # version is set in src/evermore/__init__.py dependencies = [ "equinox>=0.10.6", # eqx.field - "jaxopt >=0.6", # jaxopt.LGBFGS ] [project.optional-dependencies] test = ["pytest >=6", "pytest-cov >=3"] -dev = ["pytest >=6", "pytest-cov >=3"] +dev = ["pytest >=6", "pytest-cov >=3", "optax", "jaxopt >=0.6"] docs = [ "sphinx>=7.0", "myst_parser>=0.13", diff --git a/src/evermore/__init__.py b/src/evermore/__init__.py index eb11bd6..f9a52d6 100644 --- a/src/evermore/__init__.py +++ b/src/evermore/__init__.py @@ -18,18 +18,15 @@ __all__ = [ "__version__", "effect", - "ipy_util", - "likelihood", - "optimizer", + "loss", "pdf", "util", + "sample", # explicitely expose some classes - "Model", - "Result", "Parameter", "modifier", - "staterror", - "autostaterrors", + # "staterror", + # "autostaterrors", "compose", ] @@ -40,17 +37,18 @@ def __dir__(): from evermore import ( # noqa: E402 effect, - ipy_util, - likelihood, - optimizer, + loss, pdf, + sample, util, ) -from evermore.model import Model, Result # noqa: E402 + +# from evermore.model import Model, Result from evermore.modifier import ( # noqa: E402 - autostaterrors, + # autostaterrors, compose, modifier, - staterror, ) + +# staterror, from evermore.parameter import Parameter # noqa: E402 diff --git a/src/evermore/custom_types.py b/src/evermore/custom_types.py index 28bd629..b901250 100644 --- a/src/evermore/custom_types.py +++ b/src/evermore/custom_types.py @@ -1,10 +1,9 @@ from collections.abc import Callable from typing import Any -import jax +import jaxtyping -ArrayLike = jax.typing.ArrayLike -AddOrMul = Callable[[ArrayLike, ArrayLike], jax.Array] +AddOrMul = Callable[[jaxtyping.ArrayLike, jaxtyping.ArrayLike], jaxtyping.Array] class Sentinel: diff --git a/src/evermore/effect.py b/src/evermore/effect.py index c300bd4..9fb4642 100644 --- a/src/evermore/effect.py +++ b/src/evermore/effect.py @@ -3,10 +3,10 @@ from typing import TYPE_CHECKING, ClassVar import equinox as eqx -import jax import jax.numpy as jnp +from jaxtyping import Array, Float -from evermore.custom_types import AddOrMul, ArrayLike +from evermore.custom_types import AddOrMul from evermore.parameter import Parameter from evermore.pdf import Flat, Gauss, HashablePDF, Poisson from evermore.util import as1darray @@ -40,7 +40,7 @@ def constraint(self) -> HashablePDF: ... @abc.abstractmethod - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: ... @@ -51,7 +51,7 @@ class unconstrained(Effect): def constraint(self) -> HashablePDF: return Flat() - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: return parameter.value @@ -59,18 +59,18 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: class gauss(Effect): - width: ArrayLike = eqx.field(static=True, converter=as1darray) + width: Array = eqx.field(static=True, converter=as1darray) apply_op: ClassVar[AddOrMul] = operator.mul - def __init__(self, width: ArrayLike) -> None: + def __init__(self, width: Array) -> None: self.width = width @property def constraint(self) -> HashablePDF: return Gauss(mean=0.0, width=1.0) - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: """ Implementation with (inverse) CDFs is defined as follows: @@ -92,20 +92,20 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: class shape(Effect): - up: jax.Array = eqx.field(converter=as1darray) - down: jax.Array = eqx.field(converter=as1darray) + up: Array = eqx.field(converter=as1darray) + down: Array = eqx.field(converter=as1darray) apply_op: ClassVar[AddOrMul] = operator.add def __init__( self, - up: jax.Array, - down: jax.Array, + up: Array, + down: Array, ) -> None: self.up = up # +1 sigma self.down = down # -1 sigma - def vshift(self, sf: jax.Array, sumw: jax.Array) -> jax.Array: + def vshift(self, sf: Array, sumw: Array) -> Array: factor = sf dx_sum = self.up + self.down - 2 * sumw dx_diff = self.up - self.down @@ -128,7 +128,7 @@ def vshift(self, sf: jax.Array, sumw: jax.Array) -> jax.Array: def constraint(self) -> HashablePDF: return Gauss(mean=0.0, width=1.0) - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: sf = parameter.value return self.vshift(sf=sf, sumw=sumw) # shift = self.vshift(sf=sf, sumw=sumw) @@ -138,20 +138,18 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: class lnN(Effect): - width: tuple[ArrayLike, ArrayLike] = eqx.field(static=True) + width: Float[Array, "2"] = eqx.field(static=True) apply_op: ClassVar[AddOrMul] = operator.mul def __init__( self, - width: tuple[ArrayLike, ArrayLike], + width: Float[Array, "2"], # given as (down, up) ) -> None: - # given as (down, up) - assert isinstance(width, tuple) - assert len(width) == 2 + assert width.shape == (2,) self.width = width - def interpolate(self, parameter: Parameter) -> jax.Array: + def interpolate(self, parameter: Parameter) -> Array: # https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/blob/be488af288361ef101859a398ae618131373cad7/src/ProcessNormalization.cc#L112-L129 x = parameter.value lo, hi = self.width @@ -171,7 +169,7 @@ def interpolate(self, parameter: Parameter) -> jax.Array: def constraint(self) -> HashablePDF: return Gauss(mean=0.0, width=1.0) - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: """ Implementation with (inverse) CDFs is defined as follows: @@ -193,16 +191,16 @@ def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: class poisson(Effect): - lamb: jax.Array = eqx.field(static=True, converter=as1darray) + lamb: Array = eqx.field(static=True, converter=as1darray) apply_op: ClassVar[AddOrMul] = operator.mul - def __init__(self, lamb: jax.Array) -> None: + def __init__(self, lamb: Array) -> None: self.lamb = lamb @property def constraint(self) -> HashablePDF: return Poisson(lamb=self.lamb) - def scale_factor(self, parameter: Parameter, sumw: jax.Array) -> jax.Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: return parameter.value + 1 diff --git a/src/evermore/ipy_util.py b/src/evermore/ipy_util.py deleted file mode 100644 index 6268f8a..0000000 --- a/src/evermore/ipy_util.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Any - -import jax.numpy as jnp - -from evermore.custom_types import ArrayLike -from evermore.model import Model - -__all__ = ["interactive"] - - -def __dir__(): - return __all__ - - -def interactive(model: Model) -> None: - import ipywidgets as widgets - import matplotlib.pyplot as plt - - def slider(v: ArrayLike) -> widgets.FloatSlider: - return widgets.FloatSlider(min=v - 2, max=v + 2, step=0.01, value=v) - - fig, ax = plt.subplots() - - expectation = model.evaluate().expectation() - bins = jnp.arange(expectation.size) - - art = ax.bar(bins, expectation, color="gray") - - @widgets.interact( - **{name: slider(param.value) for name, param in model.parameters.items()} - ) - def update(**kwargs: Any) -> None: - m = model.update(values=kwargs) - res = m.evaluate() - - expectation = res.expectation() - print("Expectation:", expectation) - print("Constraint (logpdf):", m.parameter_constraints()) - - nonlocal art - art.remove() - - art = ax.bar(bins, expectation, color="gray") - - ax.set_xticks(bins) - ax.set_xticklabels(list(map(str, bins))) - ax.set_xlabel(r"Bin #") - ax.set_ylabel(r"S+B model") - plt.tight_layout() - plt.show() diff --git a/src/evermore/likelihood.py b/src/evermore/likelihood.py deleted file mode 100644 index f4c160e..0000000 --- a/src/evermore/likelihood.py +++ /dev/null @@ -1,128 +0,0 @@ -from typing import TYPE_CHECKING, cast - -import equinox as eqx -import jax -import jax.numpy as jnp - -from evermore.custom_types import Sentinel, _NoValue -from evermore.model import Model - -__all__ = [ - "NLL", - "Hessian", - "CovMatrix", - "SampleToy", -] - - -def __dir__(): - return __all__ - - -class BaseModule(eqx.Module): - """ - Base module to hold the `model` and the `observation`. - """ - - model: Model - observation: jax.Array = eqx.field(converter=jnp.asarray) - - def __init__(self, model: Model, observation: jax.Array) -> None: - self.model = model - self.observation = observation - - -class NLL(BaseModule): - """ - Negative log-likelihood (NLL). - """ - - def logpdf(self, *args, **kwargs) -> jax.Array: - return jax.scipy.stats.poisson.logpmf(*args, **kwargs) - - def __call__(self, values: dict | Sentinel = _NoValue) -> jax.Array: - if values is _NoValue: - values = self.model.parameter_values - model = self.model.update(values=values) - res = model.evaluate() - nll = jnp.sum( - self.logpdf(self.observation, res.expectation()) - - self.logpdf(self.observation, self.observation), - axis=-1, - ) - # add constraints - constraints = jax.tree_util.tree_leaves(model.parameter_constraints()) - nll += sum(constraints) - nll += model.nll_boundary_penalty() - return -jnp.sum(nll) - - -class Hessian(BaseModule): - """ - Hessian matrix. - """ - - NLL: NLL - - def __init__(self, model: Model, observation: jax.Array) -> None: - super().__init__(model=model, observation=observation) - self.NLL = NLL(model=model, observation=observation) - - def __call__(self, values: dict | Sentinel = _NoValue) -> jax.Array: - if values is _NoValue: - values = self.model.parameter_values - if TYPE_CHECKING: - values = cast(dict, values) - hessian = jax.hessian(self.NLL, argnums=0)(values) - hessian, _ = jax.tree_util.tree_flatten(hessian) - hessian = jnp.array(hessian) - new_shape = len(values) - return jnp.reshape(hessian, (new_shape, new_shape)) - - -class CovMatrix(Hessian): - """ - Covariance matrix. - """ - - def __call__(self, values: dict | Sentinel = _NoValue) -> jax.Array: - if values is _NoValue: - values = self.model.parameter_values - hessian = super().__call__(values=values) - return jnp.linalg.inv(hessian) - - -class SampleToy(BaseModule): - """ - Sample a toy from the model. - """ - - CovMatrix: CovMatrix - - def __init__(self, model: Model, observation: jax.Array) -> None: - super().__init__(model=model, observation=observation) - self.CovMatrix = CovMatrix(model=model, observation=observation) - - def __call__( - self, - values: dict | Sentinel = _NoValue, - key: jax.Array | Sentinel = _NoValue, - ) -> dict[str, jax.Array]: - if values is _NoValue: - values = self.model.parameter_values - if key is _NoValue: - key = jax.random.PRNGKey(1234) - if TYPE_CHECKING: - key = cast(jax.Array, key) - cov = self.CovMatrix(values=values) - _values, tree_def = jax.tree_util.tree_flatten( - self.model.update(values=values).parameter_values - ) - sampled_values = jax.random.multivariate_normal( - key=key, - mean=jnp.concatenate(_values), - cov=cov, - ) - new_values = jax.tree_util.tree_unflatten(tree_def, sampled_values) - model = self.model.update(values=new_values) - return model.evaluate().expectations diff --git a/src/evermore/loss.py b/src/evermore/loss.py new file mode 100644 index 0000000..2cc487c --- /dev/null +++ b/src/evermore/loss.py @@ -0,0 +1,74 @@ +from collections.abc import Callable + +import equinox as eqx +import jax +import jax.numpy as jnp +from jaxtyping import Array + +from evermore.parameter import Parameter +from evermore.util import _params_map + +__all__ = [ + "get_param_constraints", + "PoissonNLL", +] + + +def __dir__(): + return __all__ + + +def get_param_constraints(module: eqx.Module) -> dict: + constraints = {} + + def _constraint(param: Parameter) -> Array: + if param.constraints: + if len(param.constraints) > 1: + msg = f"More than one constraint per parameter is not allowed. Got: {param.constraints}" + raise ValueError(msg) + return next(iter(param.constraints)).logpdf(param.value) + return jnp.array([0.0]) + + # constraints from pdfs + constraints["pdfs"] = _params_map(module, _constraint) + # constraints from boundaries + constraints["boundaries"] = _params_map(module, lambda p: p.boundary_penalty) + return constraints + + +class PoissonNLL(eqx.Module): + """ + Poisson negative log-likelihood (NLL). + + Usage: + + .. code-block:: python + + import evermore as evm + + nll = evm.loss.PoissonNLL() + + def loss(model, x, y): + expectation = model(x) + constraints = evm.loss.get_param_constraints(model) + loss = nll(expectation, y, evm.util.sum_leaves(constraints)) + return loss + """ + + @property + def logpdf(self) -> Callable: + return jax.scipy.stats.poisson.logpmf + + @jax.named_scope("evm.loss.PoissonNLL") + def __call__( + self, expectation: Array, observation: Array, constraint: Array + ) -> Array: + # poisson log-likelihood + nll = jnp.sum( + self.logpdf(observation, expectation) + - self.logpdf(observation, observation), + axis=-1, + ) + # add constraint + nll += constraint + return -jnp.sum(nll) diff --git a/src/evermore/model.py b/src/evermore/model.py deleted file mode 100644 index 24e31cf..0000000 --- a/src/evermore/model.py +++ /dev/null @@ -1,205 +0,0 @@ -from __future__ import annotations - -import abc -from typing import TYPE_CHECKING, Any, cast - -import equinox as eqx -import jax -import jax.numpy as jnp -import jax.tree_util as jtu - -from evermore.custom_types import Sentinel, _NoValue -from evermore.parameter import Parameter -from evermore.util import deep_update - -__all__ = [ - "Result", - "Model", -] - - -def __dir__(): - return __all__ - - -class Result(eqx.Module): - expectations: dict[str, jax.Array] - - def __init__(self) -> None: - self.expectations = {} - - def add(self, process: str, expectation: jax.Array) -> Result: - self.expectations[process] = expectation - return self - - def expectation(self) -> jax.Array: - return cast(jax.Array, sum(jtu.tree_leaves(self.expectations))) - - -def _is_parameter(leaf: Any) -> bool: - return isinstance(leaf, Parameter) - - -def _is_none_or_is_parameter(leaf: Any) -> bool: - return leaf is None or _is_parameter(leaf) - - -class Model(eqx.Module): - """ - A model describing nuisance parameters, templates (histograms), and how they interact. - It is requires to implement the `evaluate` method, which returns an `Result` object. - - Example: - - .. code-block:: python - - import equinox as eqx - import jax - import jax.numpy as jnp - - import evermore as evm - - - # Define a simple model with two processes and two parameters - class MyModel(evm.Model): - def __call__(self, processes: dict, parameters: dict) -> evm.Result: - res = evm.Result() - - # signal - mu_mod = evm.modifier(name="mu", parameter=parameters["mu"], effect=evm.effect.unconstrained()) - res.add(process="signal", expectation=mu_mod(processes["signal"])) - - # background - bkg_mod = evm.modifier(name="sigma", parameter=parameters["sigma"], effect=evm.effect.lnN(0.2)) - res.add(process="background", expectation=bkg_mod(processes["background"])) - return res - - - # Setup model - processes = {"signal": jnp.array([10]), "background": jnp.array([50])} - parameters = { - "mu": evm.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)), - "sigma": evm.Parameter(value=jnp.array([0.0])), - } - - model = MyModel(processes=processes, parameters=parameters) - - # evaluate the expectation - model.evaluate().expectation() - # -> Array([60.], dtype=float32) - - %timeit model.evaluate().expectation() - # -> 485 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) - - # evaluate the expectation *fast* - @eqx.filter_jit - def eval(model) -> jax.Array: - res = model.evaluate() - return res.expectation() - - eqx.filter_jit(eval)(model) - # -> Array([60.], dtype=float32) - - %timeit eqx.filter_jit(eval)(model).block_until_ready() - # -> 202 µs ± 4.87 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each) - """ - - processes: dict - parameters: dict[str, Parameter] - auxiliary: Any - - def __init__( - self, - processes: dict, - parameters: dict, - auxiliary: Any | Sentinel = _NoValue, - ) -> None: - self.processes = processes - self.parameters = parameters - if auxiliary is _NoValue: - auxiliary = {} - self.auxiliary = auxiliary - - @property - def parameter_values(self) -> dict: - return jtu.tree_map( - lambda l: l.value, # noqa: E741 - self.parameters, - is_leaf=_is_parameter, - ) - - def parameter_constraints(self) -> dict: - def _constraint(param: Parameter) -> jax.Array: - if param.constraints: - if len(param.constraints) > 1: - msg = f"More than one constraint per parameter is not allowed. Got: {param.constraints}" - raise ValueError(msg) - return next(iter(param.constraints)).logpdf(param.value) - return jnp.array([0.0]) - - return jtu.tree_map( - _constraint, - self.parameters, - is_leaf=_is_parameter, - ) - - def update( - self, - processes: dict | Sentinel = _NoValue, - values: dict | Sentinel = _NoValue, - ) -> Model: - if values is _NoValue: - values = {} - if processes is _NoValue: - processes = {} - - if TYPE_CHECKING: - values = cast(dict, values) - processes = cast(dict, processes) - - # patch original processes with new ones - new_processes = deep_update(self.processes, processes) - - # patch original parameters with new ones - _updates = deep_update( - jtu.tree_map(lambda _: None, self.parameters, is_leaf=_is_parameter), - values, - ) - - def _update_params(update: jax.Array | None, param: Parameter) -> Parameter: - if update is None: - return param - return param.update(value=update) - - new_parameters = jtu.tree_map( - _update_params, - _updates, - self.parameters, - is_leaf=_is_none_or_is_parameter, - ) - - return eqx.tree_at( - lambda t: (t.processes, t.parameters), self, (new_processes, new_parameters) - ) - - def nll_boundary_penalty(self) -> jax.Array: - return cast( - jax.Array, - sum( - jtu.tree_leaves( - jtu.tree_map( - lambda p: p.boundary_penalty, - self.parameters, - is_leaf=_is_parameter, - ) - ) - ), - ) - - @abc.abstractmethod - def __call__(self, processes: dict, parameters: dict) -> Result: - ... - - def evaluate(self) -> Result: - # evaluate the model with its current state - return self(self.processes, self.parameters) diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index d837c0b..422effa 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -8,12 +8,11 @@ import equinox as eqx import jax import jax.numpy as jnp +from jaxtyping import Array from evermore.custom_types import AddOrMul from evermore.effect import ( DEFAULT_EFFECT, - gauss, - poisson, ) from evermore.parameter import Parameter @@ -23,8 +22,8 @@ __all__ = [ "modifier", "compose", - "staterror", - "autostaterrors", + # "staterror", + # "autostaterrors", ] @@ -34,7 +33,7 @@ def __dir__(): class ModifierBase(eqx.Module): @abc.abstractmethod - def __call__(self, sumw: jax.Array) -> jax.Array: + def __call__(self, sumw: Array) -> Array: ... @@ -75,27 +74,27 @@ class modifier(ModifierBase): modify(jnp.array([10, 20, 30])) """ - name: str parameter: Parameter effect: Effect - def __init__( - self, name: str, parameter: Parameter, effect: Effect = DEFAULT_EFFECT - ) -> None: - self.name = name + def __init__(self, parameter: Parameter, effect: Effect = DEFAULT_EFFECT) -> None: self.parameter = parameter self.effect = effect self.parameter.constraints.add(self.effect.constraint) - def scale_factor(self, sumw: jax.Array) -> jax.Array: + def scale_factor(self, sumw: Array) -> Array: return self.effect.scale_factor(parameter=self.parameter, sumw=sumw) - def __call__(self, sumw: jax.Array) -> jax.Array: + @jax.named_scope("evm.modifier") + def __call__(self, sumw: Array) -> Array: op = self.effect.apply_op shift = jnp.atleast_1d(self.scale_factor(sumw=sumw)) shift = jnp.broadcast_to(shift, sumw.shape) return op(shift, sumw) # type: ignore[call-arg] + def __matmul__(self, other: modifier) -> compose: + return compose(self, other) + class compose(ModifierBase): """ @@ -133,7 +132,7 @@ class compose(ModifierBase): eqx.filter_jit(composition)(jnp.array([10, 20, 30])) """ - modifiers: list[ModifierBase] + modifiers: list[modifier] def __init__(self, *modifiers: modifier) -> None: self.modifiers = list(modifiers) @@ -147,19 +146,12 @@ def __init__(self, *modifiers: modifier) -> None: _modifiers.append(mod) self.modifiers = _modifiers - def __check_init__(self): - # check for duplicate names - names = [m.name for m in self.modifiers] - duplicates = {name for name in names if names.count(name) > 1} - if duplicates: - msg = f"Modifiers need to have unique names, got: {duplicates}" - raise ValueError(msg) - def __len__(self) -> int: return len(self.modifiers) - def __call__(self, sumw: jax.Array) -> jax.Array: - def _prep_shift(modifier: ModifierBase, sumw: jax.Array) -> jax.Array: + @jax.named_scope("evm.compose") + def __call__(self, sumw: Array) -> Array: + def _prep_shift(modifier: modifier, sumw: Array) -> Array: shift = modifier.scale_factor(sumw=sumw) shift = jnp.atleast_1d(shift) return jnp.broadcast_to(shift, sumw.shape) @@ -181,257 +173,256 @@ def _prep_shift(modifier: ModifierBase, sumw: jax.Array) -> jax.Array: return _mult_fact * (sumw + _add_shift) -class staterror(ModifierBase): - """ - Create a staterror (barlow-beeston) modifier which acts on each bin with a different _underlying_ modifier. - - *Caution*: The instantiation of a `staterror` is not compatible with JAX-transformations (e.g. `jax.jit`)! - - Example: - - .. code-block:: python - - import jax.numpy as jnp - import evermore as evm - - hist = jnp.array([10, 20, 30]) - - p1 = evm.Parameter(value=1.0) - p2 = evm.Parameter(value=0.0) - p3 = evm.Parameter(value=0.0) - - # all bins with bin content below 10 (threshold) are treated as poisson, else gauss - modify = evm.staterror( - parameters={1: p1, 2: p2, 3: p3}, - sumw=hist, - sumw2=hist, - threshold=10.0, - ) - modify(hist) - # -> Array([13.162277, 20. , 30. ], dtype=float32) - - # jit - import equinox as eqx - - fast_modify = eqx.filter_jit(modify) - """ - - name: str = "staterror" - parameters: dict[str, Parameter] - sumw: jax.Array - sumw2: jax.Array - sumw2sqrt: jax.Array - widths: jax.Array - mask: jax.Array - threshold: float - - def __init__( - self, - parameters: dict[str, Parameter], - sumw: jax.Array, - sumw2: jax.Array, - threshold: float, - ) -> None: - self.parameters = parameters - self.sumw = sumw - self.sumw2 = sumw2 - self.sumw2sqrt = jnp.sqrt(sumw2) - self.threshold = threshold - - # calculate width - self.widths = self.sumw2sqrt / self.sumw - - # store if sumw is below threshold - self.mask = self.sumw < self.threshold - - for i, name in enumerate(self.parameters): - param = self.parameters[name] - effect = poisson(self.sumw[i]) if self.mask[i] else gauss(self.widths[i]) - param.constraints.add(effect.constraint) - - def __check_init__(self): - if not len(self.parameters) == len(self.sumw2) == len(self.sumw): - msg = ( - f"Length of parameters ({len(self.parameters)}), " - f"sumw2 ({len(self.sumw2)}) and sumw ({len(self.sumw)}) " - "must be the same." - ) - raise ValueError(msg) - if not self.threshold > 0.0: - msg = f"Threshold must be >= 0.0, got: {self.threshold}" - raise ValueError(msg) - - def scale_factor(self, sumw: jax.Array) -> jax.Array: - from functools import partial - - assert len(sumw) == len(self.parameters) == len(self.sumw2) - - values = jnp.concatenate([param.value for param in self.parameters.values()]) - idxs = jnp.arange(len(sumw)) - - # sumw where mask (poisson) else widths (gauss) - _widths = jnp.where(self.mask, self.sumw, self.widths) - - def _mod( - value: jax.Array, - width: jax.Array, - idx: jax.Array, - effect: Effect, - ) -> jax.Array: - return effect(width).scale_factor( - parameter=Parameter(value=value), - sumw=sumw[idx], - )[0] - - _poisson_mod = partial(_mod, effect=poisson) - _gauss_mod = partial(_mod, effect=gauss) - - # apply - return jnp.where( - self.mask, - jax.vmap(_poisson_mod)(values, _widths, idxs), - jax.vmap(_gauss_mod)(values, _widths, idxs), - ) - - def __call__(self, sumw: jax.Array) -> jax.Array: - # both gauss and poisson behave multiplicative - op = operator.mul - sf = self.scale_factor(sumw=sumw) - return op(jnp.atleast_1d(sf), sumw) - - -class autostaterrors(eqx.Module): - class Mode(eqx.Enumeration): - barlow_beeston_full = ( - "Barlow-Beeston (full) approach: Poisson per process and bin" - ) - poisson_gauss = "Poisson (Gauss) per process and bin if sumw < (>) threshold" - barlow_beeston_lite = "Barlow-Beeston (lite) approach" - - sumw: dict[str, jax.Array] - sumw2: dict[str, jax.Array] - masks: dict[str, jax.Array] - threshold: float - mode: str - key_template: str = eqx.field(static=True) - - def __init__( - self, - sumw: dict[str, jax.Array], - sumw2: dict[str, jax.Array], - threshold: float = 10.0, - mode: str = Mode.barlow_beeston_lite, - key_template: str = "__staterror_{process}__", - ) -> None: - self.sumw = sumw - self.sumw2 = sumw2 - self.masks = {p: _sumw < threshold for p, _sumw in sumw.items()} - self.threshold = threshold - self.mode = mode - self.key_template = key_template - - def __check_init__(self): - if jax.tree_util.tree_structure(self.sumw) != jax.tree_util.tree_structure( - self.sumw2 - ): # type: ignore[operator] - msg = ( - "The structure of `sumw` and `sumw2` needs to be identical, got " - f"`sumw`: {jax.tree_util.tree_structure(self.sumw)}) and " - f"`sumw2`: {jax.tree_util.tree_structure(self.sumw2)})" - ) - raise ValueError(msg) - if not self.threshold > 0.0: - msg = f"Threshold must be >= 0.0, got: {self.threshold}" - raise ValueError(msg) - if not isinstance(self.mode, self.Mode): - msg = f"Mode must be of type {self.Mode}, got: {self.mode}" - raise ValueError(msg) - - def prepare( - self, - ) -> tuple[dict[str, dict[str, Parameter]], dict[str, dict[str, eqx.Partial]]]: - """ - Helper to automatically create parameters used by `staterror` - for the initialisation of a `evm.Model`. - - *Caution*: This function is not compatible with JAX-transformations (e.g. `jax.jit`)! - - Example: - - .. code-block:: python - - import jax.numpy as jnp - import evermore as evm - - sumw = { - "signal": jnp.array([5, 20, 30]), - "background": jnp.array([5, 20, 30]), - } - - sumw2 = { - "signal": jnp.array([5, 20, 30]), - "background": jnp.array([5, 20, 30]), - } - - - auto = evm.autostaterrors( - sumw=sumw, - sumw2=sumw2, - threshold=10.0, - mode=evm.autostaterrors.Mode.barlow_beeston_full, - ) - parameters, staterrors = auto.prepare() - - # barlow-beeston-lite - auto2 = evm.autostaterrors( - sumw=sumw, - sumw2=sumw2, - threshold=10.0, - mode=evm.autostaterrors.Mode.barlow_beeston_lite, - ) - parameters2, staterrors2 = auto2.prepare() - - # materialize: - process = "signal" - pkey = auto.key_template.format(process=process) - modify = staterrors[pkey](parameters[pkey]) - modified_process = modify(sumw[process]) - """ - import equinox as eqx - - parameters: dict[str, dict[str, Parameter]] = {} - staterrors: dict[str, dict[str, eqx.Partial]] = {} - - for process, _sumw in self.sumw.items(): - key = self.key_template.format(process=process) - process_parameters = parameters[key] = {} - mask = self.masks[process] - for i in range(len(_sumw)): - pkey = f"{process}_{i}" - if self.mode == self.Mode.barlow_beeston_lite and not mask[i]: - # we merge all processes into one parameter - # for the barlow-beeston-lite approach where - # the bin content is above a certain threshold - pkey = f"{i}" - process_parameters[pkey] = Parameter(value=jnp.array(0.0)) - # prepare staterror - kwargs = { - "sumw": _sumw, - "sumw2": self.sumw2[process], - "threshold": self.threshold, - } - if self.mode == self.Mode.barlow_beeston_full: - kwargs["threshold"] = jnp.inf # inf -> always poisson - elif self.mode == self.Mode.barlow_beeston_lite: - kwargs["sumw"] = jnp.where( - mask, - _sumw, - sum(jax.tree_util.tree_leaves(self.sumw)), - ) - kwargs["sumw2"] = jnp.where( - mask, - self.sumw2[process], - sum(jax.tree_util.tree_leaves(self.sumw2)), - ) - staterrors[key] = eqx.Partial(staterror, **kwargs) - return parameters, staterrors +# class staterror(ModifierBase): +# """ +# Create a staterror (barlow-beeston) modifier which acts on each bin with a different _underlying_ modifier. + +# *Caution*: The instantiation of a `staterror` is not compatible with JAX-transformations (e.g. `jax.jit`)! + +# Example: + +# .. code-block:: python + +# import jax.numpy as jnp +# import evermore as evm + +# hist = jnp.array([10, 20, 30]) + +# p1 = evm.Parameter(value=1.0) +# p2 = evm.Parameter(value=0.0) +# p3 = evm.Parameter(value=0.0) + +# # all bins with bin content below 10 (threshold) are treated as poisson, else gauss +# modify = evm.staterror( +# parameters={1: p1, 2: p2, 3: p3}, +# sumw=hist, +# sumw2=hist, +# threshold=10.0, +# ) +# modify(hist) +# # -> Array([13.162277, 20. , 30. ], dtype=float32) + +# # jit +# import equinox as eqx + +# fast_modify = eqx.filter_jit(modify) +# """ + +# parameters: dict[str, Parameter] +# sumw: Array +# sumw2: Array +# sumw2sqrt: Array +# widths: Array +# mask: Array +# threshold: float + +# def __init__( +# self, +# parameters: dict[str, Parameter], +# sumw: Array, +# sumw2: Array, +# threshold: float, +# ) -> None: +# self.parameters = parameters +# self.sumw = sumw +# self.sumw2 = sumw2 +# self.sumw2sqrt = jnp.sqrt(sumw2) +# self.threshold = threshold + +# # calculate width +# self.widths = self.sumw2sqrt / self.sumw + +# # store if sumw is below threshold +# self.mask = self.sumw < self.threshold + +# for i, name in enumerate(self.parameters): +# param = self.parameters[name] +# effect = poisson(self.sumw[i]) if self.mask[i] else gauss(self.widths[i]) +# param.constraints.add(effect.constraint) + +# def __check_init__(self): +# if not len(self.parameters) == len(self.sumw2) == len(self.sumw): +# msg = ( +# f"Length of parameters ({len(self.parameters)}), " +# f"sumw2 ({len(self.sumw2)}) and sumw ({len(self.sumw)}) " +# "must be the same." +# ) +# raise ValueError(msg) +# if not self.threshold > 0.0: +# msg = f"Threshold must be >= 0.0, got: {self.threshold}" +# raise ValueError(msg) + +# def scale_factor(self, sumw: Array) -> Array: +# from functools import partial + +# assert len(sumw) == len(self.parameters) == len(self.sumw2) + +# values = jnp.concatenate([param.value for param in self.parameters.values()]) +# idxs = jnp.arange(len(sumw)) + +# # sumw where mask (poisson) else widths (gauss) +# _widths = jnp.where(self.mask, self.sumw, self.widths) + +# def _mod( +# value: Array, +# width: Array, +# idx: Array, +# effect: type[poisson] | type[gauss], +# ) -> Array: +# return effect(width).scale_factor( +# parameter=Parameter(value=value), +# sumw=sumw[idx], +# )[0] + +# _poisson_mod = partial(_mod, effect=poisson) +# _gauss_mod = partial(_mod, effect=gauss) + +# # apply +# return jnp.where( +# self.mask, +# jax.vmap(_poisson_mod)(values, _widths, idxs), +# jax.vmap(_gauss_mod)(values, _widths, idxs), +# ) + +# def __call__(self, sumw: Array) -> Array: +# # both gauss and poisson behave multiplicative +# op = operator.mul +# sf = self.scale_factor(sumw=sumw) +# return op(jnp.atleast_1d(sf), sumw) + + +# class autostaterrors(eqx.Module): +# class Mode(eqx.Enumeration): +# barlow_beeston_full = ( +# "Barlow-Beeston (full) approach: Poisson per process and bin" +# ) +# poisson_gauss = "Poisson (Gauss) per process and bin if sumw < (>) threshold" +# barlow_beeston_lite = "Barlow-Beeston (lite) approach" + +# sumw: dict[str, Array] +# sumw2: dict[str, Array] +# masks: dict[str, Array] +# threshold: float +# mode: str +# key_template: str = eqx.field(static=True) + +# def __init__( +# self, +# sumw: dict[str, Array], +# sumw2: dict[str, Array], +# threshold: float = 10.0, +# mode: str = Mode.barlow_beeston_lite, +# key_template: str = "__staterror_{process}__", +# ) -> None: +# self.sumw = sumw +# self.sumw2 = sumw2 +# self.masks = {p: _sumw < threshold for p, _sumw in sumw.items()} +# self.threshold = threshold +# self.mode = mode +# self.key_template = key_template + +# def __check_init__(self): +# if jax.tree_util.tree_structure(self.sumw) != jax.tree_util.tree_structure( +# self.sumw2 +# ): # type: ignore[operator] +# msg = ( +# "The structure of `sumw` and `sumw2` needs to be identical, got " +# f"`sumw`: {jax.tree_util.tree_structure(self.sumw)}) and " +# f"`sumw2`: {jax.tree_util.tree_structure(self.sumw2)})" +# ) +# raise ValueError(msg) +# if not self.threshold > 0.0: +# msg = f"Threshold must be >= 0.0, got: {self.threshold}" +# raise ValueError(msg) +# if not isinstance(self.mode, self.Mode): +# msg = f"Mode must be of type {self.Mode}, got: {self.mode}" +# raise ValueError(msg) + +# def prepare( +# self, +# ) -> tuple[dict[str, dict[str, Parameter]], dict[str, dict[str, eqx.Partial]]]: +# """ +# Helper to automatically create parameters used by `staterror` +# for the initialisation of a `evm.Model`. + +# *Caution*: This function is not compatible with JAX-transformations (e.g. `jax.jit`)! + +# Example: + +# .. code-block:: python + +# import jax.numpy as jnp +# import evermore as evm + +# sumw = { +# "signal": jnp.array([5, 20, 30]), +# "background": jnp.array([5, 20, 30]), +# } + +# sumw2 = { +# "signal": jnp.array([5, 20, 30]), +# "background": jnp.array([5, 20, 30]), +# } + + +# auto = evm.autostaterrors( +# sumw=sumw, +# sumw2=sumw2, +# threshold=10.0, +# mode=evm.autostaterrors.Mode.barlow_beeston_full, +# ) +# parameters, staterrors = auto.prepare() + +# # barlow-beeston-lite +# auto2 = evm.autostaterrors( +# sumw=sumw, +# sumw2=sumw2, +# threshold=10.0, +# mode=evm.autostaterrors.Mode.barlow_beeston_lite, +# ) +# parameters2, staterrors2 = auto2.prepare() + +# # materialize: +# process = "signal" +# pkey = auto.key_template.format(process=process) +# modify = staterrors[pkey](parameters[pkey]) +# modified_process = modify(sumw[process]) +# """ +# import equinox as eqx + +# parameters: dict[str, dict[str, Parameter]] = {} +# staterrors: dict[str, dict[str, eqx.Partial]] = {} + +# for process, _sumw in self.sumw.items(): +# key = self.key_template.format(process=process) +# process_parameters = parameters[key] = {} +# mask = self.masks[process] +# for i in range(len(_sumw)): +# pkey = f"{process}_{i}" +# if self.mode == self.Mode.barlow_beeston_lite and not mask[i]: +# # we merge all processes into one parameter +# # for the barlow-beeston-lite approach where +# # the bin content is above a certain threshold +# pkey = f"{i}" +# process_parameters[pkey] = Parameter(value=jnp.array(0.0)) +# # prepare staterror +# kwargs = { +# "sumw": _sumw, +# "sumw2": self.sumw2[process], +# "threshold": self.threshold, +# } +# if self.mode == self.Mode.barlow_beeston_full: +# kwargs["threshold"] = jnp.inf # inf -> always poisson +# elif self.mode == self.Mode.barlow_beeston_lite: +# kwargs["sumw"] = jnp.where( +# mask, +# _sumw, +# sum(jax.tree_util.tree_leaves(self.sumw)), +# ) +# kwargs["sumw2"] = jnp.where( +# mask, +# self.sumw2[process], +# sum(jax.tree_util.tree_leaves(self.sumw2)), +# ) +# staterrors[key] = eqx.Partial(staterror, **kwargs) +# return parameters, staterrors diff --git a/src/evermore/optimizer.py b/src/evermore/optimizer.py deleted file mode 100644 index 7fcfbd2..0000000 --- a/src/evermore/optimizer.py +++ /dev/null @@ -1,99 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable, Hashable -from typing import TYPE_CHECKING, Any, cast - -import equinox as eqx -import jax -import jaxopt - -from evermore.custom_types import Sentinel, _NoValue - -__all__ = [ - "JaxOptimizer", - "Chain", -] - - -def __dir__(): - return __all__ - - -class JaxOptimizer(eqx.Module): - """ - Wrapper around `jaxopt` optimizers to make them hashable. - This allows to pass the optimizer as a parameter to a `jax.jit` function, and setup the optimizer therein. - - Example: - - .. code-block:: python - - optimizer = JaxOptimizer.make(name="GradientDescent", settings={"maxiter": 5}) - # or, e.g.: optimizer = JaxOptimizer.make(name="LBFGS", settings={"maxiter": 10}) - - optimizer.fit(fun=nll, init_values=init_values) - """ - - name: str - _settings: tuple[tuple[str, Hashable], ...] - - def __init__(self, name: str, _settings: tuple[tuple[str, Hashable], ...]) -> None: - self.name = name - self._settings = _settings - - @classmethod - def make( - cls: type[JaxOptimizer], - name: str, - settings: dict[str, Hashable] | Sentinel = _NoValue, - ) -> JaxOptimizer: - if settings is _NoValue: - settings = {} - if TYPE_CHECKING: - settings = cast(dict[str, Hashable], settings) - return cls(name=name, _settings=tuple(settings.items())) - - @property - def settings(self) -> dict[str, Hashable]: - return dict(self._settings) - - def solver_instance(self, fun: Callable) -> jaxopt._src.base.Solver: - return getattr(jaxopt, self.name)(fun=fun, **self.settings) - - def fit( - self, fun: Callable, init_values: dict[str, jax.Array] - ) -> tuple[dict[str, jax.Array], Any]: - values, state = self.solver_instance(fun=fun).run(init_values) - return values, state - - -class Chain(eqx.Module): - """ - Chain multiple optimizers together. - They probably should have the `maxiter` setting set to a value, - in order to have a deterministic runtime behaviour. - - Example: - - .. code-block:: python - - opt1 = JaxOptimizer.make(name="GradientDescent", settings={"maxiter": 5}) - opt2 = JaxOptimizer.make(name="LBFGS", settings={"maxiter": 10}) - - chain = Chain(opt1, opt2) - # first 5 steps are minimized with GradientDescent, then 10 steps with LBFGS - chain.fit(fun=nll, init_values=init_values) - """ - - optimizers: tuple[JaxOptimizer, ...] - - def __init__(self, *optimizers: JaxOptimizer) -> None: - self.optimizers = optimizers - - def fit( - self, fun: Callable, init_values: dict[str, jax.Array] - ) -> tuple[dict[str, jax.Array], Any]: - values = init_values - for optimizer in self.optimizers: - values, state = optimizer.fit(fun=fun, init_values=values) - return values, state diff --git a/src/evermore/parameter.py b/src/evermore/parameter.py index a18adb8..195f239 100644 --- a/src/evermore/parameter.py +++ b/src/evermore/parameter.py @@ -1,12 +1,17 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import equinox as eqx -import jax import jax.numpy as jnp +from jaxtyping import Array, ArrayLike, Float from evermore.pdf import HashablePDF from evermore.util import as1darray +if TYPE_CHECKING: + from evermore.modifier import modifier + __all__ = [ "Parameter", ] @@ -17,28 +22,57 @@ def __dir__(): class Parameter(eqx.Module): - value: jax.Array = eqx.field(converter=as1darray) - bounds: tuple[jax.Array, jax.Array] = eqx.field( - static=True, converter=lambda x: tuple(map(as1darray, x)) - ) + value: Array = eqx.field(converter=as1darray) + lower: Array = eqx.field(static=True, converter=as1darray) + upper: Array = eqx.field(static=True, converter=as1darray) constraints: set[HashablePDF] = eqx.field(static=True) def __init__( self, - value: jax.Array, - bounds: tuple[jax.Array, jax.Array] = (as1darray(-jnp.inf), as1darray(jnp.inf)), + value: ArrayLike = 0.0, + lower: ArrayLike = -jnp.inf, + upper: ArrayLike = jnp.inf, ) -> None: - self.value = value - self.bounds = bounds + self.value = as1darray(value) + self.lower = as1darray(lower) + self.upper = as1darray(upper) self.constraints: set[HashablePDF] = set() - def update(self, value: jax.Array) -> Parameter: + def update(self, value: Array | Parameter) -> Parameter: + if isinstance(value, Parameter): + value = value.value return eqx.tree_at(lambda t: t.value, self, value) @property - def boundary_penalty(self) -> jax.Array: + def boundary_penalty(self) -> Array: return jnp.where( - (self.value < self.bounds[0]) | (self.value > self.bounds[1]), + (self.value < self.lower) | (self.value > self.upper), jnp.inf, 0, ) + + # shorthands + def unconstrained(self) -> modifier: + import evermore as evm + + return evm.modifier(parameter=self, effect=evm.effect.unconstrained()) + + def gauss(self, width: Array) -> modifier: + import evermore as evm + + return evm.modifier(parameter=self, effect=evm.effect.gauss(width=width)) + + def lnN(self, width: Float[Array, 2]) -> modifier: + import evermore as evm + + return evm.modifier(parameter=self, effect=evm.effect.lnN(width=width)) + + def poisson(self, lamb: Array) -> modifier: + import evermore as evm + + return evm.modifier(parameter=self, effect=evm.effect.poisson(lamb=lamb)) + + def shape(self, up: Array, down: Array) -> modifier: + import evermore as evm + + return evm.modifier(parameter=self, effect=evm.effect.shape(up=up, down=down)) diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index 3cb804b..f2d4542 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -1,10 +1,15 @@ from __future__ import annotations from abc import abstractmethod +from typing import TYPE_CHECKING, Any import equinox as eqx import jax import jax.numpy as jnp +from jaxtyping import Array, PRNGKeyArray + +if TYPE_CHECKING: + from evermore import Parameter __all__ = [ "HashablePDF", @@ -24,19 +29,23 @@ def __hash__(self) -> int: ... @abstractmethod - def logpdf(self, x: jax.Array) -> jax.Array: + def logpdf(self, x: Array) -> Array: + ... + + @abstractmethod + def pdf(self, x: Array) -> Array: ... @abstractmethod - def pdf(self, x: jax.Array) -> jax.Array: + def cdf(self, x: Array) -> Array: ... @abstractmethod - def cdf(self, x: jax.Array) -> jax.Array: + def inv_cdf(self, x: Array) -> Array: ... @abstractmethod - def inv_cdf(self, x: jax.Array) -> jax.Array: + def sample(self, key: PRNGKeyArray, parameter: Parameter) -> Array: ... @@ -44,19 +53,29 @@ class Flat(HashablePDF): def __hash__(self): return hash(self.__class__) - def logpdf(self, x: jax.Array) -> jax.Array: + def logpdf(self, x: Array) -> Array: return jnp.array([0.0]) - def pdf(self, x: jax.Array) -> jax.Array: + def pdf(self, x: Array) -> Array: return jnp.array([1.0]) - def cdf(self, x: jax.Array) -> jax.Array: + def cdf(self, x: Array) -> Array: return jnp.array([1.0]) - def inv_cdf(self, x: jax.Array) -> jax.Array: + def inv_cdf(self, x: Array) -> Array: msg = "Flat distribution has no inverse CDF." raise ValueError(msg) + def sample(self, key: PRNGKeyArray, parameter: Parameter) -> Array: + return jax.random.uniform( + key, + parameter.value.shape, + # what should be the ranges? + # +/-jnp.inf leads to nans... + # minval=parameter.lower, + # maxval=parameter.upper, + ) + class Gauss(HashablePDF): mean: float = eqx.field(static=True) @@ -69,44 +88,59 @@ def __init__(self, mean: float, width: float) -> None: def __hash__(self): return hash(self.__class__) ^ hash((self.mean, self.width)) - def logpdf(self, x: jax.Array) -> jax.Array: + def logpdf(self, x: Array) -> Array: logpdf_max = jax.scipy.stats.norm.logpdf( self.mean, loc=self.mean, scale=self.width ) unnormalized = jax.scipy.stats.norm.logpdf(x, loc=self.mean, scale=self.width) return unnormalized - logpdf_max - def pdf(self, x: jax.Array) -> jax.Array: + def pdf(self, x: Array) -> Array: return jax.scipy.stats.norm.pdf(x, loc=self.mean, scale=self.width) - def cdf(self, x: jax.Array) -> jax.Array: + def cdf(self, x: Array) -> Array: return jax.scipy.stats.norm.cdf(x, loc=self.mean, scale=self.width) - def inv_cdf(self, x: jax.Array) -> jax.Array: + def inv_cdf(self, x: Array) -> Array: return jax.scipy.stats.norm.ppf(x, loc=self.mean, scale=self.width) + def sample(self, key: PRNGKeyArray, parameter: Parameter) -> Array: + return self.mean + self.width * jax.random.normal( + key, + shape=parameter.value.shape, + dtype=parameter.value.dtype, + ) + class Poisson(HashablePDF): - lamb: jax.Array = eqx.field(static=True) + lamb: Array = eqx.field(static=True) - def __init__(self, lamb: jax.Array) -> None: + def __init__(self, lamb: Array) -> None: self.lamb = lamb def __hash__(self): - return hash(self.__class__) ^ hash(str(self.lamb)) # is this a safe hash?? + return hash(self.__class__) - def logpdf(self, x: jax.Array) -> jax.Array: + def __eq__(self, other: Any): # type: ignore[override] + if not isinstance(other, Poisson): + return ValueError(f"Cannot compare Poisson with {type(other)}") + # We need to implement __eq__ explicitely because we have a non-hashable field (lamb). + # Implementing __eq__ is necessary for the `==` operator to work and to ensure that the + # Poisson distribution is correctly added to a python set. + return jnp.all(self.lamb == other.lamb) + + def logpdf(self, x: Array) -> Array: logpdf_max = jax.scipy.stats.poisson.logpmf(self.lamb, mu=self.lamb) unnormalized = jax.scipy.stats.poisson.logpmf((x + 1) * self.lamb, mu=self.lamb) return unnormalized - logpdf_max - def pdf(self, x: jax.Array) -> jax.Array: + def pdf(self, x: Array) -> Array: return jax.scipy.stats.poisson.pmf((x + 1) * self.lamb, mu=self.lamb) - def cdf(self, x: jax.Array) -> jax.Array: + def cdf(self, x: Array) -> Array: return jax.scipy.stats.poisson.cdf((x + 1) * self.lamb, mu=self.lamb) - def inv_cdf(self, x: jax.Array) -> jax.Array: + def inv_cdf(self, x: Array) -> Array: # see: https://num.pyro.ai/en/stable/tutorials/truncated_distributions.html?highlight=poisson%20inverse#5.3-Example:-Left-truncated-Poisson def cond_fn(val): n, cdf = val @@ -121,3 +155,11 @@ def body_fn(val): cdf_start = self.cdf(start) n, _ = jax.lax.while_loop(cond_fn, body_fn, (start, cdf_start)) return n.astype(jnp.result_type(int)) + + def sample(self, key: PRNGKeyArray) -> Array: # type: ignore[override] + return jax.random.poisson( + key, + self.lamb, + shape=self.lamb.shape, + dtype=self.lamb.dtype, + ) diff --git a/src/evermore/sample.py b/src/evermore/sample.py new file mode 100644 index 0000000..13d0ec9 --- /dev/null +++ b/src/evermore/sample.py @@ -0,0 +1,36 @@ +from collections.abc import Callable + +import equinox as eqx +import jax +from jaxtyping import Array, PRNGKeyArray, PyTree + +from evermore.util import is_parameter + + +# get the PDFs from the parameters of the model +def toy_module(module: eqx.Module, key: PRNGKeyArray) -> PyTree[Callable]: + from evermore import Parameter + + params_tree = eqx.filter(module, is_parameter, is_leaf=is_parameter) + params_structure = jax.tree_util.tree_structure(params_tree) + n_params = params_structure.num_leaves # type: ignore[attr-defined] + + keys = jax.random.split(key, n_params) + keys_tree = jax.tree_util.tree_unflatten(params_structure, keys) + + def _sample(param: Parameter, key: Parameter) -> Array: + if not param.constraints: + msg = f"Parameter {param} has no constraint pdf, can't sample from it." + raise RuntimeError(msg) + if len(param.constraints) > 1: + msg = f"More than one constraint per parameter is not allowed. Got: {param.constraints}" + raise ValueError(msg) + pdf = next(iter(param.constraints)) + + # sample new value from the constraint pdf + sampled_param_value = pdf.sample(key.value, param) + + # replace the sampled parameter value and return new parameter + return eqx.tree_at(lambda p: p.value, param, sampled_param_value) + + return jax.tree_util.tree_map(_sample, params_tree, keys_tree, is_leaf=is_parameter) diff --git a/src/evermore/util.py b/src/evermore/util.py index 8ce79b2..db70722 100644 --- a/src/evermore/util.py +++ b/src/evermore/util.py @@ -1,22 +1,24 @@ from __future__ import annotations -import collections -import pprint -from collections.abc import Callable, Hashable, Iterable, Mapping -from typing import TYPE_CHECKING, Any, TypeVar, cast - +import operator +from collections.abc import Callable +from functools import partial +from typing import ( + Any, +) + +import equinox as eqx import jax import jax.numpy as jnp - -from evermore.custom_types import ArrayLike, Sentinel, _NoValue +import jax.tree_util as jtu +from jaxtyping import Array, ArrayLike, PyTree __all__ = [ - "HistDB", - "FrozenDB", + "is_parameter", + "sum_leaves", "as1darray", "dump_hlo_graph", "dump_jaxpr", - "deep_update", ] @@ -24,226 +26,37 @@ def __dir__(): return __all__ -class FrozenKeysView(collections.abc.KeysView): - """FrozenKeysView that does not print values when repr'ing.""" - - def __init__(self, mapping): - super().__init__(mapping) - self._mapping = mapping - - def __repr__(self): - return f"{type(self).__name__}({list(map(_pretty_key, self._mapping.keys()))})" - - __str__ = __repr__ - - -def _pretty_key(key): - if not isinstance(key, frozenset): - key = FrozenDB.keyify(key) - if len(key) == 1: - return next(iter(key)) - return tuple([_pretty_key(k) for k in key]) - - -def _indent(amount: int, s: str) -> str: - """Indents `s` with `amount` spaces.""" - prefix = amount * " " - return "\n".join(prefix + line for line in s.splitlines()) - - -def _pretty_dict(x): - if not isinstance(x, Mapping): - return pprint.pformat(x) - rep = "" - for key, val in x.items(): - rep += f"{_pretty_key(key)!r}: {_pretty_dict(val)},\n" - if rep: - return "{\n" + _indent(2, rep) + "\n}" - return "{}" - - -K = TypeVar("K") -V = TypeVar("V") - - -def _prepare_freeze(xs: Any) -> Any: - """Deep copy unfrozen dicts to make the dictionary FrozenDict safe.""" - if isinstance(xs, FrozenDB): - # we can safely ref share the internal state of a FrozenDict - # because it is immutable. - return xs._dict - if not isinstance(xs, dict): - # return a leaf as is. - return xs - # recursively copy dictionary to avoid ref sharing - return {FrozenDB.keyify(key): _prepare_freeze(val) for key, val in xs.items()} - - -def _check_no_duplicate_keys(keys: Iterable[Hashable]) -> None: - keys = list(keys) - if any(keys.count(x) > 1 for x in keys): - msg = f"Duplicate keys: {tuple(keys)}, this is not allowed!" - raise ValueError(msg) - - -class FrozenDB(Mapping[K, V]): - """An immutable database-like custom dict. - - Example: - - .. code-block:: python - - hists = HistDB( - { - # QCD - ("QCD", "nominal"): jnp.array([1, 1, 1, 1, 1]), - ("QCD", "JES", "Up"): jnp.array([1.5, 1.5, 1.5, 1.5, 1.5]), - ("QCD", "JES", "Down"): jnp.array([0.5, 0.5, 0.5, 0.5, 0.5]), - # DY - ("DY", "nominal"): jnp.array([2, 2, 2, 2, 2]), - ("DY", "JES", "Up"): jnp.array([2.5, 2.5, 2.5, 2.5, 2.5]), - ("DY", "JES", "Down"): jnp.array([0.7, 0.7, 0.7, 0.7, 0.7]), - } - ) - - print(hists) - # -> HistDB({ - # ('QCD', 'nominal'): Array([1, 1, 1, 1, 1], dtype=int32), - # ('QCD', 'Up', 'JES'): Array([1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32), - # ('QCD', 'Down', 'JES'): Array([0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32), - # ('DY', 'nominal'): Array([2, 2, 2, 2, 2], dtype=int32), - # ('DY', 'Up', 'JES'): Array([2.5, 2.5, 2.5, 2.5, 2.5], dtype=float32), - # ('DY', 'Down', 'JES'): Array([0.7, 0.7, 0.7, 0.7, 0.7], dtype=float32), - # }) - - print(hists["QCD"]) - # -> HistDB({ - # 'nominal': Array([1, 1, 1, 1, 1], dtype=int32), - # ('Up', 'JES'): Array([1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32), - # ('Down', 'JES'): Array([0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32), - # }) - - print(hists["JES"]) - # -> HistDB({ - # ('QCD', 'Up'): Array([1.5, 1.5, 1.5, 1.5, 1.5], dtype=float32), - # ('QCD', 'Down'): Array([0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32), - # ('DY', 'Up'): Array([2.5, 2.5, 2.5, 2.5, 2.5], dtype=float32), - # ('DY', 'Down'): Array([0.7, 0.7, 0.7, 0.7, 0.7], dtype=float32), - # }) - - # It's jit-compatible: - def foo(hists): - return (hists["QCD", "nominal"] + 1.2) ** 2 - - print(jax.jit(foo)(hists)) - # -> Array([4.84, 4.84, 4.84, 4.84, 4.84], dtype=float32, weak_type=True) - """ - - __slots__ = ("_dict",) - - if TYPE_CHECKING: - _dict: dict[frozenset, Any] - - @staticmethod - def keyify(keyish: Any) -> frozenset: - if not isinstance(keyish, tuple | list | set | frozenset): - keyish = (keyish,) - _check_no_duplicate_keys(keyish) - keyish = frozenset(keyish) - assert not any(isinstance(key, set) for key in keyish) - return keyish - - def __init__( - self, - xs: Mapping | Sentinel = _NoValue, - __unsafe_skip_copy__: bool = False, - ) -> None: - # make sure the dict is as - if xs is _NoValue: - xs = {} - data = dict(cast(Mapping, xs)) - if __unsafe_skip_copy__: - self._dict = data - else: - self._dict = _prepare_freeze(data) - - def __getitem__(self, key) -> Any: - key = self.keyify(key) - if key in self._dict: - return self._dict[key] - ret = self.__class__({k - key: v for k, v in self.items() if key <= k}) - if not ret: - raise KeyError(key) - return ret - - def __setitem__(self, key, value) -> None: - msg = f"{type(self).__name__} is immutable." - raise ValueError(msg) - - def __contains__(self, key) -> bool: - key = self.keyify(key) - return key in self._dict - - def __len__(self) -> int: - return len(self._dict) - - def __iter__(self): - return iter(self._dict) - - def keys(self) -> FrozenKeysView: - return FrozenKeysView(self._dict) +def is_parameter(leaf: Any) -> bool: + from evermore import Parameter - def values(self): - return self._dict.values() + return isinstance(leaf, Parameter) - def items(self): - for key in self._dict: - yield (key, self[key]) - def only(self, *keys) -> FrozenDB: - return self.__class__({key: self[key] for key in keys}) +K = str +V = Any - def subset(self, *keys) -> FrozenDB: - new = {} - for key in keys: - new.update({k: v for k, v in self.items() if self.keyify(key) <= k}) - return self.__class__(new) - def copy(self) -> FrozenDB: - return self.__class__(self) - - def __repr__(self) -> str: - return f"{type(self).__name__}({_pretty_dict(self._dict)})" - - def as_compact_dict(self): - return {"/".join(sorted(map(str, k))): v for k, v in self.items()} - - -def _flatten(tree): - return (tuple(tree.values()), tuple(tree.keys())) - - -def _make_unflatten(cls: type[FrozenDB]) -> Callable: - def _unflatten(keys, values): - return cls(dict(zip(keys, values, strict=True)), __unsafe_skip_copy__=True) - - return _unflatten +def _filtered_module_map( + module: eqx.Module, + fun: Callable, + filter: Callable, +) -> eqx.Module: + params = eqx.filter(module, filter, is_leaf=filter) + return jtu.tree_map( + fun, + params, + is_leaf=filter, + ) -class HistDB(FrozenDB): - ... +_params_map = partial(_filtered_module_map, filter=is_parameter) -# then we register them with jax as a PyTree -for cls in HistDB, FrozenDB: - jax.tree_util.register_pytree_node( - cls, - _flatten, - _make_unflatten(cls), - ) +def sum_leaves(tree: PyTree) -> Array: + return jtu.tree_reduce(operator.add, tree) -def as1darray(x: ArrayLike) -> jax.Array: +def as1darray(x: ArrayLike) -> Array: """ Converts `x` to a 1d array. @@ -316,20 +129,3 @@ def f(x: jax.Array) -> jax.Array: filepath.write_text(dump_hlo_graph(f, x), encoding='ascii') """ return jax.xla_computation(fun)(*args, **kwargs).as_hlo_dot_graph() - - -def deep_update( - mapping: dict[K, Any], - new_mapping: dict[K, Any], -) -> dict[K, Any]: - updated_mapping = mapping.copy() - for k, v in new_mapping.items(): - if ( - k in updated_mapping - and isinstance(updated_mapping[k], dict) - and isinstance(v, dict) - ): - updated_mapping[k] = deep_update(updated_mapping[k], v) - else: - updated_mapping[k] = v - return updated_mapping diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py deleted file mode 100644 index f328e8e..0000000 --- a/tests/test_optimizer.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -from functools import partial - -import jax -import jaxopt -import pytest - -from evermore.optimizer import JaxOptimizer - - -def test_jaxoptimizer(): - opt = JaxOptimizer.make(name="GradientDescent", settings={"maxiter": 5}) - - assert opt.name == "GradientDescent" - assert opt.settings == {"maxiter": 5} - - assert isinstance(opt.solver_instance(fun=lambda x: x), jaxopt.GradientDescent) - - # jit compatibility - @partial(jax.jit, static_argnums=0) - def f(optimizer): - @jax.jit - def fun(x): - return (x - 2.0) ** 2 - - init_values = 1.0 - values, _ = optimizer.fit(fun=fun, init_values=init_values) - return values - - assert f(opt) == pytest.approx(2.0) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 4b6e3f0..3147288 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -8,11 +8,12 @@ def test_parameter(): - p = evm.Parameter(value=jnp.array(1.0), bounds=(jnp.array(0.0), jnp.array(2.0))) + p = evm.Parameter(value=jnp.array(1.0), lower=jnp.array(0.0), upper=jnp.array(2.0)) assert p.value == 1.0 assert p.update(jnp.array(2.0)).value == 2.0 - assert p.bounds == (0.0, 2.0) + assert p.lower == 0.0 + assert p.upper == 2.0 assert p.boundary_penalty == 0.0 assert p.update(jnp.array(3.0)).boundary_penalty == jnp.inf @@ -42,7 +43,7 @@ def test_gauss(): def test_lnN(): p = evm.Parameter(value=jnp.array(0.0)) - ln = evm.effect.lnN(width=(0.9, 1.1)) + ln = evm.effect.lnN(width=jnp.array([0.9, 1.1])) assert ln.constraint == Gauss(mean=0.0, width=1.0) assert ln.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) @@ -67,21 +68,15 @@ def test_modifier(): norm = evm.Parameter(value=jnp.array(0.0)) # unconstrained effect - m_unconstrained = evm.modifier( - name="mu", parameter=mu, effect=evm.effect.unconstrained() - ) + m_unconstrained = evm.modifier(parameter=mu, effect=evm.effect.unconstrained()) assert m_unconstrained(jnp.array([10])) == pytest.approx(11) # gauss effect - m_gauss = evm.modifier( - name="norm", parameter=norm, effect=evm.effect.gauss(jnp.array(0.1)) - ) + m_gauss = evm.modifier(parameter=norm, effect=evm.effect.gauss(jnp.array(0.1))) assert m_gauss(jnp.array([10])) == pytest.approx(10) # lnN effect - m_lnN = evm.modifier( - name="norm", parameter=norm, effect=evm.effect.lnN(width=(0.9, 1.1)) - ) + m_lnN = evm.modifier(parameter=norm, effect=evm.effect.lnN(width=(0.9, 1.1))) assert m_lnN(jnp.array([10])) == pytest.approx(10) # poisson effect # FIXME @@ -99,13 +94,9 @@ def test_compose(): norm = evm.Parameter(value=jnp.array(0.0)) # unconstrained effect - m_unconstrained = evm.modifier( - name="mu", parameter=mu, effect=evm.effect.unconstrained() - ) + m_unconstrained = evm.modifier(parameter=mu, effect=evm.effect.unconstrained()) # gauss effect - m_gauss = evm.modifier( - name="norm", parameter=norm, effect=evm.effect.gauss(jnp.array(0.1)) - ) + m_gauss = evm.modifier(parameter=norm, effect=evm.effect.gauss(jnp.array(0.1))) # compose m = evm.compose(m_unconstrained, m_gauss) diff --git a/tests/test_util.py b/tests/test_util.py index 75afe94..9e50171 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,46 +2,7 @@ import jax -from evermore.util import FrozenDB, as1darray - - -def get_frozendb(): - return FrozenDB( - { - # QCD - ("a", "b"): 1, - ("a", "d", "e"): 2, - ("a", "d", "f"): 3, - # DY - ("g", "b"): 4, - ("g", "d", "e"): 5, - ("g", "d", "f"): 6, - } - ) - - -def test_frozendb_len(): - db = get_frozendb() - - assert len(db) == 6 - - -def test_frozendb_getitem(): - db = get_frozendb() - - assert db["a"]["b"] == 1 - assert db["a", "b"] == 1 - assert db["b"] == FrozenDB({"a": 1, "g": 4}) - - -def test_frozendb_jitcompatible(): - db = get_frozendb() - - @jax.jit - def fun(db): - return (db["a", "b"] + 1) ** 2 - - assert fun(db) == 4 +from evermore.util import as1darray def test_as1darray(): From 039e311a58f10155efba288dbe8b08ada9342fd6 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 6 Mar 2024 15:10:14 +0100 Subject: [PATCH 02/22] fix tests --- tests/test_parameter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 3147288..987e66c 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -76,7 +76,9 @@ def test_modifier(): assert m_gauss(jnp.array([10])) == pytest.approx(10) # lnN effect - m_lnN = evm.modifier(parameter=norm, effect=evm.effect.lnN(width=(0.9, 1.1))) + m_lnN = evm.modifier( + parameter=norm, effect=evm.effect.lnN(width=jnp.array([0.9, 1.1])) + ) assert m_lnN(jnp.array([10])) == pytest.approx(10) # poisson effect # FIXME From 32e966bb6e2bd9cc00e3885d36ced53deb116c41 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 6 Mar 2024 15:10:49 +0100 Subject: [PATCH 03/22] sample: account for non-evm.Parameters in model --- src/evermore/sample.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/evermore/sample.py b/src/evermore/sample.py index 13d0ec9..43fc87d 100644 --- a/src/evermore/sample.py +++ b/src/evermore/sample.py @@ -11,7 +11,7 @@ def toy_module(module: eqx.Module, key: PRNGKeyArray) -> PyTree[Callable]: from evermore import Parameter - params_tree = eqx.filter(module, is_parameter, is_leaf=is_parameter) + params_tree, rest_tree = eqx.partition(module, is_parameter, is_leaf=is_parameter) params_structure = jax.tree_util.tree_structure(params_tree) n_params = params_structure.num_leaves # type: ignore[attr-defined] @@ -33,4 +33,10 @@ def _sample(param: Parameter, key: Parameter) -> Array: # replace the sampled parameter value and return new parameter return eqx.tree_at(lambda p: p.value, param, sampled_param_value) - return jax.tree_util.tree_map(_sample, params_tree, keys_tree, is_leaf=is_parameter) + # sample for each parameter + sampled_params_tree = jax.tree_util.tree_map( + _sample, params_tree, keys_tree, is_leaf=is_parameter + ) + + # combine the sampled parameters with the rest of the model and return it + return eqx.combine(sampled_params_tree, rest_tree, is_leaf=is_parameter) From eaa7e47e4433faf86a868d51b94fb74c2ad48fe0 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 6 Mar 2024 15:11:21 +0100 Subject: [PATCH 04/22] better descriptions of dnn_weights_constraint example --- examples/dnn_weights_constraint.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/dnn_weights_constraint.py b/examples/dnn_weights_constraint.py index f18f145..62d2a4c 100644 --- a/examples/dnn_weights_constraint.py +++ b/examples/dnn_weights_constraint.py @@ -10,9 +10,12 @@ class LinearConstrained(eqx.Module): biases: jax.Array def __init__(self, in_size, out_size, key): - self.biases = jax.random.normal(key, (out_size,)) - self.weights = evm.Parameter(value=jax.random.normal(key, (out_size, in_size))) + wkey, bkey = jax.random.split(key) + # weights + self.weights = evm.Parameter(value=jax.random.normal(wkey, (out_size, in_size))) self.weights.constraints.add(evm.pdf.Gauss(mean=0.0, width=0.5)) + # biases + self.biases = jax.random.normal(bkey, (out_size,)) def __call__(self, x: jax.Array): return self.weights.value @ x + self.biases From 892540cd0d1d5c124b8551619885f47589adcba4 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 6 Mar 2024 15:11:56 +0100 Subject: [PATCH 05/22] add example for NLL profiling --- examples/nll_profiling.py | 66 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/examples/nll_profiling.py b/examples/nll_profiling.py index e69de29..ea37ff5 100644 --- a/examples/nll_profiling.py +++ b/examples/nll_profiling.py @@ -0,0 +1,66 @@ +import equinox as eqx +import jax +import jax.numpy as jnp +import jax.tree_util as jtu +import optax +from jaxtyping import Array + +import evermore as evm + + +def fixed_mu_fit(mu: Array) -> Array: + from model import hists, model, observation + + nll = evm.loss.PoissonNLL() + + optim = optax.sgd(learning_rate=1e-2) + opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array)) + + model = eqx.tree_at(lambda t: t.mu.value, model, mu) + + # filter out mu from the model (no gradients will be calculated for mu!) + # see: https://github.com/patrick-kidger/equinox/blob/main/examples/frozen_layer.ipynb + filter_spec = jtu.tree_map(lambda _: True, model) + filter_spec = eqx.tree_at( + lambda tree: tree.mu.value, + filter_spec, + replace=False, + ) + + @eqx.filter_jit + def loss(diff_model, static_model, hists, observation): + model = eqx.combine(diff_model, static_model) + expectations = model(hists) + constraints = evm.loss.get_param_constraints(model) + return nll( + expectation=evm.util.sum_leaves(expectations), + observation=observation, + constraint=evm.util.sum_leaves(constraints), + ) + + @eqx.filter_jit + def make_step(model, opt_state, events, observation): + # differentiate + diff_model, static_model = eqx.partition(model, filter_spec) + grads = eqx.filter_grad(loss)(diff_model, static_model, events, observation) + updates, opt_state = optim.update(grads, opt_state) + # apply nuisance parameter and DNN weight updates + model = eqx.apply_updates(model, updates) + return model, opt_state + + # minimize model with 1000 steps + for _ in range(1000): + model, opt_state = make_step(model, opt_state, hists, observation) + diff_model, static_model = eqx.partition(model, filter_spec) + return loss(diff_model, static_model, hists, observation) + + +mus = jnp.linspace(0, 5, 11) +# for loop over mu values +for mu in mus: + print(f"[for-loop] mu={mu:.2f} - NLL={fixed_mu_fit(jnp.array(mu)):.6f}") + +# or vectorized!!! +likelihood_scan = jax.vmap(fixed_mu_fit)(mus) +for mu, nll in zip(mus, likelihood_scan, strict=False): + print(f"[vectorized] mu={mu:.2f} - NLL={nll:.6f}") From 0a3a4cf5c1567b0bf63d6c19a038d5161bb6fa2c Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 6 Mar 2024 15:29:11 +0100 Subject: [PATCH 06/22] polish examples; add jax as dependency; better error in sample function --- examples/grad_nll.py | 24 ++++++++++++++++++++++++ examples/model.py | 18 +----------------- examples/nll_fit.py | 39 +++++++++++++++++++++++++++++++++++++++ examples/nll_profiling.py | 9 ++++++--- pyproject.toml | 2 ++ src/evermore/loss.py | 9 ++------- src/evermore/sample.py | 2 +- 7 files changed, 75 insertions(+), 28 deletions(-) diff --git a/examples/grad_nll.py b/examples/grad_nll.py index e69de29..0850829 100644 --- a/examples/grad_nll.py +++ b/examples/grad_nll.py @@ -0,0 +1,24 @@ +import equinox as eqx +import jax.numpy as jnp +from model import hists, model, observation + +import evermore as evm + +nll = evm.loss.PoissonNLL() + + +@eqx.filter_jit +def loss(model, hists, observation): + expectations = model(hists) + constraints = evm.loss.get_param_constraints(model) + loss_val = nll( + expectation=evm.util.sum_leaves(expectations), + observation=observation, + ) + # add constraint + loss_val += evm.util.sum_leaves(constraints) + return -jnp.sum(loss_val) + + +loss_val = loss(model, hists, observation) +grads = eqx.filter_grad(loss)(model, hists, observation) diff --git a/examples/model.py b/examples/model.py index be416a0..25cebf8 100644 --- a/examples/model.py +++ b/examples/model.py @@ -60,20 +60,4 @@ def __call__(self, hists: dict[Any, jax.Array]) -> dict[str, jax.Array]: } observation = jnp.array([37]) - -nll = evm.loss.PoissonNLL() - - -@eqx.filter_jit -def loss(model, hists, observation): - expectations = model(hists) - constraints = evm.loss.get_param_constraints(model) - return nll( - expectation=evm.util.sum_leaves(expectations), - observation=observation, - constraint=evm.util.sum_leaves(constraints), - ) - - -loss_val = loss(model, hists, observation) -grads = eqx.filter_grad(loss)(model, hists, observation) +expectations = model(hists) diff --git a/examples/nll_fit.py b/examples/nll_fit.py index e69de29..3d42579 100644 --- a/examples/nll_fit.py +++ b/examples/nll_fit.py @@ -0,0 +1,39 @@ +import equinox as eqx +import jax.numpy as jnp +import optax +from model import hists, model, observation + +import evermore as evm + +optim = optax.sgd(learning_rate=1e-2) +opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array)) + +nll = evm.loss.PoissonNLL() + + +@eqx.filter_jit +def loss(model, hists, observation): + expectations = model(hists) + constraints = evm.loss.get_param_constraints(model) + loss_val = nll( + expectation=evm.util.sum_leaves(expectations), + observation=observation, + ) + # add constraint + loss_val += evm.util.sum_leaves(constraints) + return -jnp.sum(loss_val) + + +@eqx.filter_jit +def make_step(model, opt_state, events, observation): + # differentiate full analysis + grads = eqx.filter_grad(loss)(model, events, observation) + updates, opt_state = optim.update(grads, opt_state) + # apply nuisance parameter and DNN weight updates + model = eqx.apply_updates(model, updates) + return model, opt_state + + +# minimize model with 1000 steps +for _ in range(1000): + model, opt_state = make_step(model, opt_state, hists, observation) diff --git a/examples/nll_profiling.py b/examples/nll_profiling.py index ea37ff5..01271e9 100644 --- a/examples/nll_profiling.py +++ b/examples/nll_profiling.py @@ -32,11 +32,13 @@ def loss(diff_model, static_model, hists, observation): model = eqx.combine(diff_model, static_model) expectations = model(hists) constraints = evm.loss.get_param_constraints(model) - return nll( + loss_val = nll( expectation=evm.util.sum_leaves(expectations), observation=observation, - constraint=evm.util.sum_leaves(constraints), ) + # add constraint + loss_val += evm.util.sum_leaves(constraints) + return -2 * jnp.sum(loss_val) @eqx.filter_jit def make_step(model, opt_state, events, observation): @@ -60,7 +62,8 @@ def make_step(model, opt_state, events, observation): for mu in mus: print(f"[for-loop] mu={mu:.2f} - NLL={fixed_mu_fit(jnp.array(mu)):.6f}") + # or vectorized!!! likelihood_scan = jax.vmap(fixed_mu_fit)(mus) for mu, nll in zip(mus, likelihood_scan, strict=False): - print(f"[vectorized] mu={mu:.2f} - NLL={nll:.6f}") + print(f"[jax.vmap] mu={mu:.2f} - NLL={nll:.6f}") diff --git a/pyproject.toml b/pyproject.toml index 34223e0..168df55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ classifiers = [ ] dynamic = ["version"] # version is set in src/evermore/__init__.py dependencies = [ + "jax", + "jaxtyping", "equinox>=0.10.6", # eqx.field ] diff --git a/src/evermore/loss.py b/src/evermore/loss.py index 2cc487c..e638aa9 100644 --- a/src/evermore/loss.py +++ b/src/evermore/loss.py @@ -60,15 +60,10 @@ def logpdf(self) -> Callable: return jax.scipy.stats.poisson.logpmf @jax.named_scope("evm.loss.PoissonNLL") - def __call__( - self, expectation: Array, observation: Array, constraint: Array - ) -> Array: + def __call__(self, expectation: Array, observation: Array) -> Array: # poisson log-likelihood - nll = jnp.sum( + return jnp.sum( self.logpdf(observation, expectation) - self.logpdf(observation, observation), axis=-1, ) - # add constraint - nll += constraint - return -jnp.sum(nll) diff --git a/src/evermore/sample.py b/src/evermore/sample.py index 43fc87d..5ea6632 100644 --- a/src/evermore/sample.py +++ b/src/evermore/sample.py @@ -20,7 +20,7 @@ def toy_module(module: eqx.Module, key: PRNGKeyArray) -> PyTree[Callable]: def _sample(param: Parameter, key: Parameter) -> Array: if not param.constraints: - msg = f"Parameter {param} has no constraint pdf, can't sample from it." + msg = f"Parameter {param} has no constraint pdf, can't sample from it. Maybe you need to call the model once to populate all constraints?" raise RuntimeError(msg) if len(param.constraints) > 1: msg = f"More than one constraint per parameter is not allowed. Got: {param.constraints}" From 6202f2f83dcc4dadbacd46a29cb9477ddc1f9643 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 6 Mar 2024 15:34:00 +0100 Subject: [PATCH 07/22] explicitely add jaxlib as dep --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 168df55..fc5c765 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ dynamic = ["version"] # version is set in src/evermore/__init__.py dependencies = [ "jax", + "jaxlib", "jaxtyping", "equinox>=0.10.6", # eqx.field ] From e3380ee0ea779255f8a3f6e5721adbd700b7a25c Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Thu, 7 Mar 2024 15:39:02 +0100 Subject: [PATCH 08/22] add deps to dev env --- pixi.lock | 3002 ++++++++++++++++++++++++++++++++++++++++++++++++++++- pixi.toml | 3 + 2 files changed, 2967 insertions(+), 38 deletions(-) diff --git a/pixi.lock b/pixi.lock index 86d6418..2686974 100644 --- a/pixi.lock +++ b/pixi.lock @@ -7,55 +7,157 @@ environments: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.11-hd590300_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-1.1.0-hd590300_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.1.0-hd590300_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-python-1.1.0-py311hb755f60_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hd590300_5.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/c-ares-1.27.0-hd590300_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2024.2.2-hbcca054_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.0-h3faef2a_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py311h9547e67_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/dbus-1.13.6-h5008d03_3.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/linux-64/docutils-0.20.1-py311h38be061_3.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/expat-2.5.0-hcb278e6_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.49.0-py311h459d7ec_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/freetype-2.12.1-h267a509_2.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/gettext-0.21.1-h27087fc_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.4-hfc55251_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.4-hfc55251_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/graphite2-1.3.13-h58526e2_1001.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.9-h8e1006c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.9-h98fc4e7_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.3.0-h3d44ed6_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/icu-73.2-h59595ed_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/jaxlib-0.4.23-cpu_py311hc0fb0b9_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/keyutils-1.6.1-h166bdaf_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py311h9547e67_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.2-h659d440_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/lame-3.100-h166bdaf_1003.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.16-hb7c19ff_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.40-h41732ed_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/lerc-4.0.0-h27087fc_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/linux-64/libabseil-20230802.1-cxx17_h59595ed_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-21_linux64_openblas.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libbrotlicommon-1.1.0-hd590300_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libbrotlidec-1.1.0-hd590300_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libbrotlienc-1.1.0-hd590300_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libcap-2.69-h0f662aa_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.9.0-21_linux64_openblas.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_hb11cfb5_4.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libclang13-15.0.7-default_ha2b6cf4_4.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h4637d8d_4.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.19-hd590300_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20191231-he28a2e2_2.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/libevent-2.1.12-hf998b51_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.5.0-hcb278e6_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/libflac-1.4.3-h59595ed_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-13.2.0-h807b86a_5.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.3-hd590300_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libgfortran-ng-13.2.0-h69a702a_5.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libgfortran5-13.2.0-ha4646dd_5.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.4-h783c2da_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libgomp-13.2.0-h807b86a_5.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libgpg-error-1.48-h71f35ed_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libgrpc-1.59.3-hd6c4280_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libiconv-1.17-hd590300_2.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libjpeg-turbo-3.0.0-hd590300_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-21_linux64_openblas.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-hb3ce162_4.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libogg-1.3.4-h7f98852_1.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.26-pthreads_h413a1c8_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libopus-1.3.1-h7f98852_1.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.43-h2797004_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libpq-16.2-h33b98f1_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libprotobuf-4.24.4-hf27288f_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libre2-11-2023.06.02-h7a70373_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.2-hc60ed4a_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.45.1-h2797004_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_5.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsystemd0-255-h3516f8a_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.6.0-ha9c0a0a_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libvorbis-1.3.7-h9c3ff4c_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.2-hd590300_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.15-h0b41bf4_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.6.0-hd429924_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.12.5-h232c23b_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.13-hd590300_5.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.9.4-hcb278e6_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/markupsafe-2.1.5-py311h459d7ec_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.3-py311h38be061_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.3-py311h54ef318_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ml_dtypes-0.3.2-py311h320fe9a_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/mpg123-1.32.4-h59595ed_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_6.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_6.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.4-h59595ed_2.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/nspr-4.35-h27087fc_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/nss-3.98-h1d7d5a4_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.4-py311h64a7726_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.2-h488ebb8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.2.1-hd590300_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.42-hcad00b1_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/pillow-10.2.0-py311ha6c5da5_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/pixman-0.43.2-h59595ed_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-h36c2ea0_1001.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-16.1-hb77b528_5.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/pyqt-5.15.9-py311hf0fb5b6_5.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.12.2-py311hb755f60_5.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.11.0-he550d4f_1_cpython.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.11-4_cp311.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0.1-py311h459d7ec_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h5810be5_19.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/re2-2023.06.02-h2873b5e_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/scipy-1.12.0-py311h64a7726_2.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.12-py311hb755f60_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/tornado-6.4-py311h459d7ec_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.0-hd590300_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-image-0.4.0-h8ee46fc_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-keysyms-0.4.0-h8ee46fc_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-renderutil-0.3.9-hd590300_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-wm-0.4.1-h8ee46fc_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.41-hd590300_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-kbproto-1.0.7-h7f98852_1002.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libice-1.1.1-hd590300_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.4-h7391055_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.7-h8ee46fc_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxau-1.0.11-hd590300_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxdmcp-1.1.3-h7f98852_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxext-1.3.4-h0b41bf4_2.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrender-0.9.11-hd590300_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-renderproto-0.11.1-h7f98852_1002.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-xextproto-7.3.0-h0b41bf4_1003.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-xf86vidmodeproto-2.3.1-h7f98852_1002.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-xproto-7.0.31-h7f98852_1007.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/linux-64/yaml-0.2.5-h7f98852_2.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.5-hfc55251_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-0.7.16-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/asttokens-2.4.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/babel-2.14.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/certifi-2024.2.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.3.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/decorator-5.1.1-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/executing-2.0.1-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-source-code-pro-2.038-h77eed37_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-h77eed37_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/fonts-conda-ecosystem-1-0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/fonts-conda-forge-1-0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/idna-3.6-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/imageio-2.34.0-pyh4b66e23_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/importlib-metadata-7.0.1-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/importlib_metadata-7.0.1-hd8ed1ab_0.conda @@ -67,6 +169,9 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/matplotlib-inline-0.1.6-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/mdit-py-plugins-0.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/mplhep-0.3.35-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/mplhep_data-0.0.3-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/myst-parser-2.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/opt-einsum-3.3.0-hd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/opt_einsum-3.3.0-pyhc1e730c_2.conda @@ -75,11 +180,14 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/pexpect-4.9.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pickleshare-0.7.5-py_1003.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pip-24.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/ply-3.11-py_1.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/prompt-toolkit-3.0.42-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ptyprocess-0.7.0-pyhd3deb0d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pure_eval-0.2.2-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.9.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/requests-2.31.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/setuptools-69.1.1-pyhd8ed1ab_0.conda @@ -93,9 +201,12 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-qthelp-1.0.7-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/stack_data-0.6.2-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/traitlets-5.14.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.9.0-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2024a-h0c530f3_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/uhi-0.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/urllib3-2.2.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/wcwidth-0.2.13-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/wheel-0.42.0-pyhd8ed1ab_0.conda @@ -107,10 +218,12 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/certifi-2024.2.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.3.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/decorator-5.1.1-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/executing-2.0.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/idna-3.6-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/imageio-2.34.0-pyh4b66e23_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/importlib-metadata-7.0.1-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/importlib_metadata-7.0.1-hd8ed1ab_0.conda @@ -122,6 +235,9 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/matplotlib-inline-0.1.6-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/mdit-py-plugins-0.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/mplhep-0.3.35-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/mplhep_data-0.0.3-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/myst-parser-2.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/opt-einsum-3.3.0-hd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/opt_einsum-3.3.0-pyhc1e730c_2.conda @@ -134,7 +250,9 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/ptyprocess-0.7.0-pyhd3deb0d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pure_eval-0.2.2-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.9.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/requests-2.31.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/setuptools-69.1.1-pyhd8ed1ab_0.conda @@ -151,36 +269,59 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/traitlets-5.14.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.9.0-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2024a-h0c530f3_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/uhi-0.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/urllib3-2.2.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/wcwidth-0.2.13-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/wheel-0.42.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/brotli-1.1.0-h0dc2134_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/brotli-bin-1.1.0-h0dc2134_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/brotli-python-1.1.0-py311hdf8f085_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/bzip2-1.0.8-h10d778d_5.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/c-ares-1.27.0-h10d778d_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/ca-certificates-2024.2.2-h8857fd0_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/contourpy-1.2.0-py311h7bea37d_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/docutils-0.20.1-py311h6eed73b_3.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/fonttools-4.49.0-py311he705e18_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/freetype-2.12.1-h60636b9_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/jaxlib-0.4.23-cpu_py311h7977596_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/kiwisolver-1.4.5-py311h5fe6e05_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/lcms2-2.16-ha2f27b4_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/lerc-4.0.0-hb486fe8_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/osx-64/libabseil-20230802.1-cxx17_h048a20a_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libblas-3.9.0-21_osx64_openblas.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/libbrotlicommon-1.1.0-h0dc2134_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/libbrotlidec-1.1.0-h0dc2134_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/libbrotlienc-1.1.0-h0dc2134_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libcblas-3.9.0-21_osx64_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libcxx-16.0.6-hd57cbcb_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/libdeflate-1.19-ha4e1b8e_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libffi-3.4.2-h0d85af4_5.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/osx-64/libgfortran-5.0.0-13_2_0_h97931a8_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libgfortran5-13.2.0-h2873a65_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libgrpc-1.59.3-ha7f534c_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/libjpeg-turbo-3.0.0-h0dc2134_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/liblapack-3.9.0-21_osx64_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libopenblas-0.3.26-openmp_hfef2a42_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/libpng-1.6.43-h92b6c6a_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libprotobuf-4.24.4-hc4f2305_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libre2-11-2023.06.02-h4694dbf_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libsqlite-3.45.1-h92b6c6a_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/libtiff-4.6.0-h684deea_2.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/libwebp-base-1.3.2-h0dc2134_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/libxcb-1.15-hb7f2c08_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libzlib-1.2.13-h8a1eda9_5.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/llvm-openmp-17.0.6-hb6ac08f_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/markupsafe-2.1.5-py311he705e18_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/matplotlib-3.8.3-py311h6eed73b_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/matplotlib-base-3.8.3-py311h6ff1f5f_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/ml_dtypes-0.3.2-py311h8f6166a_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/ncurses-6.4-h93d8f39_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/numpy-1.26.4-py311hc43a94b_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/openjpeg-2.5.2-h7310d3a_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/openssl-3.2.1-hd75f5a5_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/pillow-10.2.0-py311hea5c87a_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/pthread-stubs-0.4-hc929b4f_1001.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/osx-64/python-3.11.0-he7542f4_1_cpython.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/python_abi-3.11-4_cp311.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/pyyaml-6.0.1-py311h2725bcf_1.conda @@ -188,8 +329,12 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-64/readline-8.2-h9e318b2_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/scipy-1.12.0-py311h86d0cd9_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/tk-8.6.13-h1abcd95_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/tornado-6.4-py311he705e18_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/xorg-libxau-1.0.11-h0dc2134_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/xorg-libxdmcp-1.1.3-h35c211d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/osx-64/xz-5.2.6-h775f41a_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/osx-64/yaml-0.2.5-h0d85af4_2.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/osx-64/zstd-1.5.5-h829000d_0.conda osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/alabaster-0.7.16-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/asttokens-2.4.1-pyhd8ed1ab_0.conda @@ -197,10 +342,12 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/certifi-2024.2.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.3.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/decorator-5.1.1-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/executing-2.0.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/idna-3.6-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/imageio-2.34.0-pyh4b66e23_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/importlib-metadata-7.0.1-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/importlib_metadata-7.0.1-hd8ed1ab_0.conda @@ -212,6 +359,9 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/matplotlib-inline-0.1.6-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/mdit-py-plugins-0.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/mplhep-0.3.35-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/mplhep_data-0.0.3-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/myst-parser-2.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/opt-einsum-3.3.0-hd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/opt_einsum-3.3.0-pyhc1e730c_2.conda @@ -224,7 +374,9 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/ptyprocess-0.7.0-pyhd3deb0d_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pure_eval-0.2.2-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.17.2-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pysocks-1.7.1-pyha2e5f31_6.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.9.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/requests-2.31.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/setuptools-69.1.1-pyhd8ed1ab_0.conda @@ -241,36 +393,59 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/traitlets-5.14.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.9.0-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2024a-h0c530f3_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/uhi-0.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/urllib3-2.2.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/wcwidth-0.2.13-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/wheel-0.42.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-3.17.0-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/brotli-1.1.0-hb547adb_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/brotli-bin-1.1.0-hb547adb_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/brotli-python-1.1.0-py311ha891d26_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/bzip2-1.0.8-h93a5062_5.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/c-ares-1.27.0-h93a5062_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ca-certificates-2024.2.2-hf0a4a13_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/contourpy-1.2.0-py311hd03642b_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/docutils-0.20.1-py311h267d04e_3.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/fonttools-4.48.1-py311h05b510d_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/freetype-2.12.1-hadb7bae_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/jaxlib-0.4.23-cpu_py311h95f3fdf_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/kiwisolver-1.4.5-py311he4fd1f5_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lcms2-2.16-ha0e7c42_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lerc-4.0.0-h9a09cb3_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libabseil-20230802.1-cxx17_h13dd4ca_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libblas-3.9.0-21_osxarm64_openblas.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libbrotlicommon-1.1.0-hb547adb_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libbrotlidec-1.1.0-hb547adb_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libbrotlienc-1.1.0-hb547adb_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcblas-3.9.0-21_osxarm64_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcxx-16.0.6-h4653b0c_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libdeflate-1.19-hb547adb_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libffi-3.4.2-h3422bc3_5.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgfortran-5.0.0-13_2_0_hd922786_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgfortran5-13.2.0-hf226fd6_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgrpc-1.59.3-h9560976_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libjpeg-turbo-3.0.0-hb547adb_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/liblapack-3.9.0-21_osxarm64_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libopenblas-0.3.26-openmp_h6c19121_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libpng-1.6.43-h091b4b1_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libprotobuf-4.24.4-h810fc01_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libre2-11-2023.06.02-h1753957_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.45.1-h091b4b1_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libtiff-4.6.0-ha8a6c65_2.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libwebp-base-1.3.2-hb547adb_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libxcb-1.15-hf346824_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libzlib-1.2.13-h53f4e23_5.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/llvm-openmp-17.0.6-hcd81f8e_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/markupsafe-2.1.5-py311h05b510d_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/matplotlib-3.8.3-py311ha1ab1f8_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/matplotlib-base-3.8.3-py311hb58f1d1_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ml_dtypes-0.3.2-py311hfbe21a1_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ncurses-6.4-h463b476_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-1.26.4-py311h7125741_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openjpeg-2.5.2-h9f1df11_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.2.1-h0d3ecfb_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pillow-10.2.0-py311hb9c5795_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-h27ca646_1001.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.11.0-h3ba56d0_1_cpython.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python_abi-3.11-4_cp311.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyyaml-6.0.1-py311heffc1b2_1.conda @@ -278,8 +453,12 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/readline-8.2-h92ec313_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/scipy-1.12.0-py311h4f9446f_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tornado-6.4-py311h05b510d_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxau-1.0.11-hb547adb_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxdmcp-1.1.3-h27ca646_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xz-5.2.6-h57fd34a_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/yaml-0.2.5-h3422bc3_2.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.5-h4f39d0f_0.conda packages: - kind: conda name: _libgcc_mutex @@ -325,6 +504,21 @@ packages: license_family: BSD size: 18365 timestamp: 1704848898483 +- kind: conda + name: alsa-lib + version: 1.2.11 + build: hd590300_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.11-hd590300_1.conda + sha256: 0e2b75b9834a6e520b13db516f7cf5c9cea8f0bbc9157c978444173dacb98fec + md5: 0bb492cca54017ea314b809b1ee3a176 + depends: + - libgcc-ng >=12 + license: LGPL-2.1-or-later + license_family: GPL + size: 554699 + timestamp: 1709396557528 - kind: conda name: asttokens version: 2.4.1 @@ -341,6 +535,21 @@ packages: license_family: Apache size: 28922 timestamp: 1698341257884 +- kind: conda + name: attr + version: 2.5.1 + build: h166bdaf_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2 + sha256: 82c13b1772c21fc4a17441734de471d3aabf82b61db9b11f4a1bd04a9c4ac324 + md5: d9c69a24ad678ffce24c6543a0176b00 + depends: + - libgcc-ng >=12 + license: GPL-2.0-or-later + license_family: GPL + size: 71042 + timestamp: 1660065501192 - kind: conda name: babel version: 2.14.0 @@ -358,6 +567,107 @@ packages: license_family: BSD size: 7609750 timestamp: 1702422720584 +- kind: conda + name: brotli + version: 1.1.0 + build: h0dc2134_1 + build_number: 1 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/brotli-1.1.0-h0dc2134_1.conda + sha256: 4bf66d450be5d3f9ebe029b50f818d088b1ef9666b1f19e90c85479c77bbdcde + md5: 9272dd3b19c4e8212f8542cefd5c3d67 + depends: + - brotli-bin 1.1.0 h0dc2134_1 + - libbrotlidec 1.1.0 h0dc2134_1 + - libbrotlienc 1.1.0 h0dc2134_1 + license: MIT + license_family: MIT + size: 19530 + timestamp: 1695990310168 +- kind: conda + name: brotli + version: 1.1.0 + build: hb547adb_1 + build_number: 1 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/brotli-1.1.0-hb547adb_1.conda + sha256: 62d1587deab752fcee07adc371eb20fcadc09f72c0c85399c22b637ca858020f + md5: a33aa58d448cbc054f887e39dd1dfaea + depends: + - brotli-bin 1.1.0 hb547adb_1 + - libbrotlidec 1.1.0 hb547adb_1 + - libbrotlienc 1.1.0 hb547adb_1 + license: MIT + license_family: MIT + size: 19506 + timestamp: 1695990588610 +- kind: conda + name: brotli + version: 1.1.0 + build: hd590300_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/brotli-1.1.0-hd590300_1.conda + sha256: f2d918d351edd06c55a6c2d84b488fe392f85ea018ff227daac07db22b408f6b + md5: f27a24d46e3ea7b70a1f98e50c62508f + depends: + - brotli-bin 1.1.0 hd590300_1 + - libbrotlidec 1.1.0 hd590300_1 + - libbrotlienc 1.1.0 hd590300_1 + - libgcc-ng >=12 + license: MIT + license_family: MIT + size: 19383 + timestamp: 1695990069230 +- kind: conda + name: brotli-bin + version: 1.1.0 + build: h0dc2134_1 + build_number: 1 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/brotli-bin-1.1.0-h0dc2134_1.conda + sha256: 7ca3cfb4c5df314ed481301335387ab2b2ee651e2c74fbb15bacc795c664a5f1 + md5: ece565c215adcc47fc1db4e651ee094b + depends: + - libbrotlidec 1.1.0 h0dc2134_1 + - libbrotlienc 1.1.0 h0dc2134_1 + license: MIT + license_family: MIT + size: 16660 + timestamp: 1695990286737 +- kind: conda + name: brotli-bin + version: 1.1.0 + build: hb547adb_1 + build_number: 1 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/brotli-bin-1.1.0-hb547adb_1.conda + sha256: 8fbfc2834606292016f2faffac67deea4c5cdbc21a61169f0b355e1600105a24 + md5: 990d04f8c017b1b77103f9a7730a5f12 + depends: + - libbrotlidec 1.1.0 hb547adb_1 + - libbrotlienc 1.1.0 hb547adb_1 + license: MIT + license_family: MIT + size: 17001 + timestamp: 1695990551239 +- kind: conda + name: brotli-bin + version: 1.1.0 + build: hd590300_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.1.0-hd590300_1.conda + sha256: a641abfbaec54f454c8434061fffa7fdaa9c695e8a5a400ed96b4f07c0c00677 + md5: 39f910d205726805a958da408ca194ba + depends: + - libbrotlidec 1.1.0 hd590300_1 + - libbrotlienc 1.1.0 hd590300_1 + - libgcc-ng >=12 + license: MIT + license_family: MIT + size: 18980 + timestamp: 1695990054140 - kind: conda name: brotli-python version: 1.1.0 @@ -529,6 +839,35 @@ packages: license: ISC size: 155725 timestamp: 1706844034242 +- kind: conda + name: cairo + version: 1.18.0 + build: h3faef2a_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.0-h3faef2a_0.conda + sha256: 142e2639a5bc0e99c44d76f4cc8dce9c6a2d87330c4beeabb128832cd871a86e + md5: f907bb958910dc404647326ca80c263e + depends: + - fontconfig >=2.14.2,<3.0a0 + - fonts-conda-ecosystem + - freetype >=2.12.1,<3.0a0 + - icu >=73.2,<74.0a0 + - libgcc-ng >=12 + - libglib >=2.78.0,<3.0a0 + - libpng >=1.6.39,<1.7.0a0 + - libstdcxx-ng >=12 + - libxcb >=1.15,<1.16.0a0 + - libzlib >=1.2.13,<1.3.0a0 + - pixman >=0.42.2,<1.0a0 + - xorg-libice >=1.1.1,<2.0a0 + - xorg-libsm >=1.2.4,<2.0a0 + - xorg-libx11 >=1.8.6,<2.0a0 + - xorg-libxext >=1.3.4,<2.0a0 + - xorg-libxrender >=0.9.11,<0.10.0a0 + - zlib + license: LGPL-2.1-only or MPL-1.1 + size: 982351 + timestamp: 1697028423052 - kind: conda name: certifi version: 2024.2.2 @@ -573,6 +912,93 @@ packages: license_family: BSD size: 25170 timestamp: 1666700778190 +- kind: conda + name: contourpy + version: 1.2.0 + build: py311h7bea37d_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/contourpy-1.2.0-py311h7bea37d_0.conda + sha256: 40bca4a644e0c0b0e6d58cef849ba02d4f218af715f7a5787d41845797f3b8a9 + md5: 6711c052d956af4973a16749236a0387 + depends: + - __osx >=10.9 + - libcxx >=16.0.6 + - numpy >=1.20,<2 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + license: BSD-3-Clause + license_family: BSD + size: 248078 + timestamp: 1699042040747 +- kind: conda + name: contourpy + version: 1.2.0 + build: py311h9547e67_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py311h9547e67_0.conda + sha256: 2c76e2a970b74eef92ef9460aa705dbdc506dd59b7382bfbedce39d9c189d7f4 + md5: 40828c5b36ef52433e21f89943e09f33 + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - numpy >=1.20,<2 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + license: BSD-3-Clause + license_family: BSD + size: 255843 + timestamp: 1699041590533 +- kind: conda + name: contourpy + version: 1.2.0 + build: py311hd03642b_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/contourpy-1.2.0-py311hd03642b_0.conda + sha256: 3ec341c3a33bbb7f60e9a96214e0e08c4ba9e4a553b18104194e7843abbb4ef4 + md5: c0fa0bea0af7ecdea23bf983655fa2d0 + depends: + - __osx >=10.9 + - libcxx >=16.0.6 + - numpy >=1.20,<2 + - python >=3.11,<3.12.0a0 + - python >=3.11,<3.12.0a0 *_cpython + - python_abi 3.11.* *_cp311 + license: BSD-3-Clause + license_family: BSD + size: 240223 + timestamp: 1699041881051 +- kind: conda + name: cycler + version: 0.12.1 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda + sha256: f221233f21b1d06971792d491445fd548224641af9443739b4b7b6d5d72954a8 + md5: 5cd86562580f274031ede6aa6aa24441 + depends: + - python >=3.8 + license: BSD-3-Clause + license_family: BSD + size: 13458 + timestamp: 1696677888423 +- kind: conda + name: dbus + version: 1.13.6 + build: h5008d03_3 + build_number: 3 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/dbus-1.13.6-h5008d03_3.tar.bz2 + sha256: 8f5f995699a2d9dbdd62c61385bfeeb57c82a681a7c8c5313c395aa0ccab68a5 + md5: ecfff944ba3960ecb334b9a2663d708d + depends: + - expat >=2.4.2,<3.0a0 + - libgcc-ng >=9.4.0 + - libglib >=2.70.2,<3.0a0 + license: GPL-2.0-or-later + license_family: GPL + size: 618596 + timestamp: 1640112124844 - kind: conda name: decorator version: 5.1.1 @@ -664,6 +1090,374 @@ packages: license_family: MIT size: 27689 timestamp: 1698580072627 +- kind: conda + name: expat + version: 2.5.0 + build: hcb278e6_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/expat-2.5.0-hcb278e6_1.conda + sha256: 36dfeb4375059b3bba75ce9b38c29c69fd257342a79e6cf20e9f25c1523f785f + md5: 8b9b5aca60558d02ddaa09d599e55920 + depends: + - libexpat 2.5.0 hcb278e6_1 + - libgcc-ng >=12 + license: MIT + license_family: MIT + size: 136778 + timestamp: 1680190541750 +- kind: conda + name: font-ttf-dejavu-sans-mono + version: '2.37' + build: hab24e00_0 + subdir: noarch + noarch: generic + url: https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2 + sha256: 58d7f40d2940dd0a8aa28651239adbf5613254df0f75789919c4e6762054403b + md5: 0c96522c6bdaed4b1566d11387caaf45 + license: BSD-3-Clause + license_family: BSD + size: 397370 + timestamp: 1566932522327 +- kind: conda + name: font-ttf-inconsolata + version: '3.000' + build: h77eed37_0 + subdir: noarch + noarch: generic + url: https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2 + sha256: c52a29fdac682c20d252facc50f01e7c2e7ceac52aa9817aaf0bb83f7559ec5c + md5: 34893075a5c9e55cdafac56607368fc6 + license: OFL-1.1 + license_family: Other + size: 96530 + timestamp: 1620479909603 +- kind: conda + name: font-ttf-source-code-pro + version: '2.038' + build: h77eed37_0 + subdir: noarch + noarch: generic + url: https://conda.anaconda.org/conda-forge/noarch/font-ttf-source-code-pro-2.038-h77eed37_0.tar.bz2 + sha256: 00925c8c055a2275614b4d983e1df637245e19058d79fc7dd1a93b8d9fb4b139 + md5: 4d59c254e01d9cde7957100457e2d5fb + license: OFL-1.1 + license_family: Other + size: 700814 + timestamp: 1620479612257 +- kind: conda + name: font-ttf-ubuntu + version: '0.83' + build: h77eed37_1 + build_number: 1 + subdir: noarch + noarch: generic + url: https://conda.anaconda.org/conda-forge/noarch/font-ttf-ubuntu-0.83-h77eed37_1.conda + sha256: 056c85b482d58faab5fd4670b6c1f5df0986314cca3bc831d458b22e4ef2c792 + md5: 6185f640c43843e5ad6fd1c5372c3f80 + license: LicenseRef-Ubuntu-Font-Licence-Version-1.0 + license_family: Other + size: 1619820 + timestamp: 1700944216729 +- kind: conda + name: fontconfig + version: 2.14.2 + build: h14ed4e7_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda + sha256: 155d534c9037347ea7439a2c6da7c24ffec8e5dd278889b4c57274a1d91e0a83 + md5: 0f69b688f52ff6da70bccb7ff7001d1d + depends: + - expat >=2.5.0,<3.0a0 + - freetype >=2.12.1,<3.0a0 + - libgcc-ng >=12 + - libuuid >=2.32.1,<3.0a0 + - libzlib >=1.2.13,<1.3.0a0 + license: MIT + license_family: MIT + size: 272010 + timestamp: 1674828850194 +- kind: conda + name: fonts-conda-ecosystem + version: '1' + build: '0' + subdir: noarch + noarch: generic + url: https://conda.anaconda.org/conda-forge/noarch/fonts-conda-ecosystem-1-0.tar.bz2 + sha256: a997f2f1921bb9c9d76e6fa2f6b408b7fa549edd349a77639c9fe7a23ea93e61 + md5: fee5683a3f04bd15cbd8318b096a27ab + depends: + - fonts-conda-forge + license: BSD-3-Clause + license_family: BSD + size: 3667 + timestamp: 1566974674465 +- kind: conda + name: fonts-conda-forge + version: '1' + build: '0' + subdir: noarch + noarch: generic + url: https://conda.anaconda.org/conda-forge/noarch/fonts-conda-forge-1-0.tar.bz2 + sha256: 53f23a3319466053818540bcdf2091f253cbdbab1e0e9ae7b9e509dcaa2a5e38 + md5: f766549260d6815b0c52253f1fb1bb29 + depends: + - font-ttf-dejavu-sans-mono + - font-ttf-inconsolata + - font-ttf-source-code-pro + - font-ttf-ubuntu + license: BSD-3-Clause + license_family: BSD + size: 4102 + timestamp: 1566932280397 +- kind: conda + name: fonttools + version: 4.48.1 + build: py311h05b510d_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/fonttools-4.48.1-py311h05b510d_0.conda + sha256: f1b89183449690090b6f187eca93d978184b4aedce5626f946135dcfc7072f3b + md5: d82b00861d6f37cece9a7925359a9faa + depends: + - brotli + - munkres + - python >=3.11,<3.12.0a0 + - python >=3.11,<3.12.0a0 *_cpython + - python_abi 3.11.* *_cp311 + license: MIT + license_family: MIT + size: 2762067 + timestamp: 1707240469460 +- kind: conda + name: fonttools + version: 4.49.0 + build: py311h459d7ec_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.49.0-py311h459d7ec_0.conda + sha256: bbf00a8da6c109cb139dd1e691052081e7e1e28ff2a849e7297c9e71588a6d6f + md5: d66c9e36ab104f94e35b015c86c2fcb4 + depends: + - brotli + - libgcc-ng >=12 + - munkres + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + license: MIT + license_family: MIT + size: 2817685 + timestamp: 1708049363863 +- kind: conda + name: fonttools + version: 4.49.0 + build: py311he705e18_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/fonttools-4.49.0-py311he705e18_0.conda + sha256: 8ac8c8836616dcf366fd539951367d1e0f3a0f3e519287b3218665cb37366bfc + md5: fc14300cb29ba11efaaa294b3efb14e0 + depends: + - brotli + - munkres + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + license: MIT + license_family: MIT + size: 2745147 + timestamp: 1708049594531 +- kind: conda + name: freetype + version: 2.12.1 + build: h267a509_2 + build_number: 2 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/freetype-2.12.1-h267a509_2.conda + sha256: b2e3c449ec9d907dd4656cb0dc93e140f447175b125a3824b31368b06c666bb6 + md5: 9ae35c3d96db2c94ce0cef86efdfa2cb + depends: + - libgcc-ng >=12 + - libpng >=1.6.39,<1.7.0a0 + - libzlib >=1.2.13,<1.3.0a0 + license: GPL-2.0-only OR FTL + size: 634972 + timestamp: 1694615932610 +- kind: conda + name: freetype + version: 2.12.1 + build: h60636b9_2 + build_number: 2 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/freetype-2.12.1-h60636b9_2.conda + sha256: b292cf5a25f094eeb4b66e37d99a97894aafd04a5683980852a8cbddccdc8e4e + md5: 25152fce119320c980e5470e64834b50 + depends: + - libpng >=1.6.39,<1.7.0a0 + - libzlib >=1.2.13,<1.3.0a0 + license: GPL-2.0-only OR FTL + size: 599300 + timestamp: 1694616137838 +- kind: conda + name: freetype + version: 2.12.1 + build: hadb7bae_2 + build_number: 2 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/freetype-2.12.1-hadb7bae_2.conda + sha256: 791673127e037a2dc0eebe122dc4f904cb3f6e635bb888f42cbe1a76b48748d9 + md5: e6085e516a3e304ce41a8ee08b9b89ad + depends: + - libpng >=1.6.39,<1.7.0a0 + - libzlib >=1.2.13,<1.3.0a0 + license: GPL-2.0-only OR FTL + size: 596430 + timestamp: 1694616332835 +- kind: conda + name: gettext + version: 0.21.1 + build: h27087fc_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/gettext-0.21.1-h27087fc_0.tar.bz2 + sha256: 4fcfedc44e4c9a053f0416f9fc6ab6ed50644fca3a761126dbd00d09db1f546a + md5: 14947d8770185e5153fdd04d4673ed37 + depends: + - libgcc-ng >=12 + license: LGPL-2.1-or-later AND GPL-3.0-or-later + size: 4320628 + timestamp: 1665673494324 +- kind: conda + name: glib + version: 2.78.4 + build: hfc55251_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.4-hfc55251_0.conda + sha256: 316c95dcbde46b7418d2b667a7e0c1d05101b673cd8c691d78d8699600a07a5b + md5: f36a7b2420c3fc3c48a3d609841d8fee + depends: + - gettext >=0.21.1,<1.0a0 + - glib-tools 2.78.4 hfc55251_0 + - libgcc-ng >=12 + - libglib 2.78.4 h783c2da_0 + - libstdcxx-ng >=12 + - libzlib >=1.2.13,<1.3.0a0 + - python * + license: LGPL-2.1-or-later + size: 489127 + timestamp: 1708284952839 +- kind: conda + name: glib-tools + version: 2.78.4 + build: hfc55251_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.4-hfc55251_0.conda + sha256: e94494b895f77ba54922ffb1dcfb7f1a987591b823eb5ce608afb2e2391d7d82 + md5: d184ba1bf15a2bbb3be6118c90fd487d + depends: + - libgcc-ng >=12 + - libglib 2.78.4 h783c2da_0 + - libstdcxx-ng >=12 + - libzlib >=1.2.13,<1.3.0a0 + license: LGPL-2.1-or-later + size: 111383 + timestamp: 1708284914557 +- kind: conda + name: graphite2 + version: 1.3.13 + build: h58526e2_1001 + build_number: 1001 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/graphite2-1.3.13-h58526e2_1001.tar.bz2 + sha256: 65da967f3101b737b08222de6a6a14e20e480e7d523a5d1e19ace7b960b5d6b1 + md5: 8c54672728e8ec6aa6db90cf2806d220 + depends: + - libgcc-ng >=7.5.0 + - libstdcxx-ng >=7.5.0 + license: LGPLv2 + size: 104701 + timestamp: 1604365484436 +- kind: conda + name: gst-plugins-base + version: 1.22.9 + build: h8e1006c_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/gst-plugins-base-1.22.9-h8e1006c_0.conda + sha256: a4312c96a670fdbf9ff0c3efd935e42fa4b655ff33dcc52c309b76a2afaf03f0 + md5: 614b81f8ed66c56b640faee7076ad14a + depends: + - __glibc >=2.17,<3.0.a0 + - alsa-lib >=1.2.10,<1.3.0.0a0 + - gettext >=0.21.1,<1.0a0 + - gstreamer 1.22.9 h98fc4e7_0 + - libexpat >=2.5.0,<3.0a0 + - libgcc-ng >=12 + - libglib >=2.78.3,<3.0a0 + - libogg >=1.3.4,<1.4.0a0 + - libopus >=1.3.1,<2.0a0 + - libpng >=1.6.39,<1.7.0a0 + - libstdcxx-ng >=12 + - libvorbis >=1.3.7,<1.4.0a0 + - libxcb >=1.15,<1.16.0a0 + - libzlib >=1.2.13,<1.3.0a0 + - xorg-libx11 >=1.8.7,<2.0a0 + - xorg-libxau >=1.0.11,<2.0a0 + - xorg-libxext >=1.3.4,<2.0a0 + - xorg-libxrender >=0.9.11,<0.10.0a0 + license: LGPL-2.0-or-later + license_family: LGPL + size: 2709696 + timestamp: 1706154948546 +- kind: conda + name: gstreamer + version: 1.22.9 + build: h98fc4e7_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/gstreamer-1.22.9-h98fc4e7_0.conda + sha256: aa2395bf1790f72d2706bac77430f765ec1318ca22e60e791c13ae452c045263 + md5: bcc7157b06fce7f5e055402a8135dfd8 + depends: + - __glibc >=2.17,<3.0.a0 + - gettext >=0.21.1,<1.0a0 + - glib >=2.78.3,<3.0a0 + - libgcc-ng >=12 + - libglib >=2.78.3,<3.0a0 + - libiconv >=1.17,<2.0a0 + - libstdcxx-ng >=12 + license: LGPL-2.0-or-later + license_family: LGPL + size: 1981554 + timestamp: 1706154826325 +- kind: conda + name: harfbuzz + version: 8.3.0 + build: h3d44ed6_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/harfbuzz-8.3.0-h3d44ed6_0.conda + sha256: 4b55aea03b18a4084b750eee531ad978d4a3690f63019132c26c6ad26bbe3aed + md5: 5a6f6c00ef982a9bc83558d9ac8f64a0 + depends: + - cairo >=1.18.0,<2.0a0 + - freetype >=2.12.1,<3.0a0 + - graphite2 + - icu >=73.2,<74.0a0 + - libgcc-ng >=12 + - libglib >=2.78.1,<3.0a0 + - libstdcxx-ng >=12 + license: MIT + license_family: MIT + size: 1547473 + timestamp: 1699925311766 +- kind: conda + name: icu + version: '73.2' + build: h59595ed_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/icu-73.2-h59595ed_0.conda + sha256: e12fd90ef6601da2875ebc432452590bc82a893041473bc1c13ef29001a73ea8 + md5: cc47e1facc155f91abd89b11e48e72ff + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + license: MIT + license_family: MIT + size: 12089150 + timestamp: 1692900650789 - kind: conda name: idna version: '3.6' @@ -679,6 +1473,23 @@ packages: license_family: BSD size: 50124 timestamp: 1701027126206 +- kind: conda + name: imageio + version: 2.34.0 + build: pyh4b66e23_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/imageio-2.34.0-pyh4b66e23_0.conda + sha256: be0eecc8b3ee49ffe3c38dedc4d3c121e18627624926f7d1d998e8027bce4266 + md5: b8853659d596f967c661f544dd89ede7 + depends: + - numpy + - pillow >=8.3.2 + - python >=3 + license: BSD-2-Clause + license_family: BSD + size: 290617 + timestamp: 1707730229565 - kind: conda name: imagesize version: 1.4.1 @@ -888,6 +1699,152 @@ packages: license_family: BSD size: 111589 timestamp: 1704967140287 +- kind: conda + name: keyutils + version: 1.6.1 + build: h166bdaf_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/keyutils-1.6.1-h166bdaf_0.tar.bz2 + sha256: 150c05a6e538610ca7c43beb3a40d65c90537497a4f6a5f4d15ec0451b6f5ebb + md5: 30186d27e2c9fa62b45fb1476b7200e3 + depends: + - libgcc-ng >=10.3.0 + license: LGPL-2.1-or-later + size: 117831 + timestamp: 1646151697040 +- kind: conda + name: kiwisolver + version: 1.4.5 + build: py311h5fe6e05_1 + build_number: 1 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/kiwisolver-1.4.5-py311h5fe6e05_1.conda + sha256: 586a4d0a17e6cfd9f8fdee56106d263ee40ca156832774d6e899f82ad68ac8d0 + md5: 24305b23f7995de72bbd53b7c01242a2 + depends: + - libcxx >=15.0.7 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + license: BSD-3-Clause + license_family: BSD + size: 60694 + timestamp: 1695380246398 +- kind: conda + name: kiwisolver + version: 1.4.5 + build: py311h9547e67_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py311h9547e67_1.conda + sha256: 723b0894d2d2b05a38f9c5a285d5a0a5baa27235ceab6531dbf262ba7c6955c1 + md5: 2c65bdf442b0d37aad080c8a4e0d452f + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + license: BSD-3-Clause + license_family: BSD + size: 73273 + timestamp: 1695380140676 +- kind: conda + name: kiwisolver + version: 1.4.5 + build: py311he4fd1f5_1 + build_number: 1 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/kiwisolver-1.4.5-py311he4fd1f5_1.conda + sha256: 907af50734789d47b3e8b2148dde763699dc746c64e5849baf6bd720c8cd0235 + md5: 4c871d65040b8c7bbb914df7f8f11492 + depends: + - libcxx >=15.0.7 + - python >=3.11,<3.12.0a0 + - python >=3.11,<3.12.0a0 *_cpython + - python_abi 3.11.* *_cp311 + license: BSD-3-Clause + license_family: BSD + size: 61946 + timestamp: 1695380538042 +- kind: conda + name: krb5 + version: 1.21.2 + build: h659d440_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.2-h659d440_0.conda + sha256: 259bfaae731989b252b7d2228c1330ef91b641c9d68ff87dae02cbae682cb3e4 + md5: cd95826dbd331ed1be26bdf401432844 + depends: + - keyutils >=1.6.1,<2.0a0 + - libedit >=3.1.20191231,<3.2.0a0 + - libedit >=3.1.20191231,<4.0a0 + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - openssl >=3.1.2,<4.0a0 + license: MIT + license_family: MIT + size: 1371181 + timestamp: 1692097755782 +- kind: conda + name: lame + version: '3.100' + build: h166bdaf_1003 + build_number: 1003 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/lame-3.100-h166bdaf_1003.tar.bz2 + sha256: aad2a703b9d7b038c0f745b853c6bb5f122988fe1a7a096e0e606d9cbec4eaab + md5: a8832b479f93521a9e7b5b743803be51 + depends: + - libgcc-ng >=12 + license: LGPL-2.0-only + license_family: LGPL + size: 508258 + timestamp: 1664996250081 +- kind: conda + name: lcms2 + version: '2.16' + build: ha0e7c42_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/lcms2-2.16-ha0e7c42_0.conda + sha256: 151e0c84feb7e0747fabcc85006b8973b22f5abbc3af76a9add0b0ef0320ebe4 + md5: 66f6c134e76fe13cce8a9ea5814b5dd5 + depends: + - libjpeg-turbo >=3.0.0,<4.0a0 + - libtiff >=4.6.0,<4.7.0a0 + license: MIT + license_family: MIT + size: 211959 + timestamp: 1701647962657 +- kind: conda + name: lcms2 + version: '2.16' + build: ha2f27b4_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/lcms2-2.16-ha2f27b4_0.conda + sha256: 222ebc0a55544b9922f61e75015d02861e65b48f12113af41d48ba0814e14e4e + md5: 1442db8f03517834843666c422238c9b + depends: + - libjpeg-turbo >=3.0.0,<4.0a0 + - libtiff >=4.6.0,<4.7.0a0 + license: MIT + license_family: MIT + size: 224432 + timestamp: 1701648089496 +- kind: conda + name: lcms2 + version: '2.16' + build: hb7c19ff_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.16-hb7c19ff_0.conda + sha256: 5c878d104b461b7ef922abe6320711c0d01772f4cd55de18b674f88547870041 + md5: 51bb7010fc86f70eee639b4bb7a894f5 + depends: + - libgcc-ng >=12 + - libjpeg-turbo >=3.0.0,<4.0a0 + - libtiff >=4.6.0,<4.7.0a0 + license: MIT + license_family: MIT + size: 245247 + timestamp: 1701647787198 - kind: conda name: ld_impl_linux-64 version: '2.40' @@ -902,6 +1859,49 @@ packages: license_family: GPL size: 704696 timestamp: 1674833944779 +- kind: conda + name: lerc + version: 4.0.0 + build: h27087fc_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/lerc-4.0.0-h27087fc_0.tar.bz2 + sha256: cb55f36dcd898203927133280ae1dc643368af041a48bcf7c026acb7c47b0c12 + md5: 76bbff344f0134279f225174e9064c8f + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + license: Apache-2.0 + license_family: Apache + size: 281798 + timestamp: 1657977462600 +- kind: conda + name: lerc + version: 4.0.0 + build: h9a09cb3_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/lerc-4.0.0-h9a09cb3_0.tar.bz2 + sha256: 6f068bb53dfb6147d3147d981bb851bb5477e769407ad4e6a68edf482fdcb958 + md5: de462d5aacda3b30721b512c5da4e742 + depends: + - libcxx >=13.0.1 + license: Apache-2.0 + license_family: Apache + size: 215721 + timestamp: 1657977558796 +- kind: conda + name: lerc + version: 4.0.0 + build: hb486fe8_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/lerc-4.0.0-hb486fe8_0.tar.bz2 + sha256: e41790fc0f4089726369b3c7f813117bbc14b533e0ed8b94cf75aba252e82497 + md5: f9d6a4c82889d5ecedec1d90eb673c55 + depends: + - libcxx >=13.0.1 + license: Apache-2.0 + license_family: Apache + size: 290319 + timestamp: 1657977526749 - kind: conda name: libabseil version: '20230802.1' @@ -1016,8 +2016,156 @@ packages: - liblapacke 3.9.0 21_osxarm64_openblas license: BSD-3-Clause license_family: BSD - size: 14915 - timestamp: 1705980172730 + size: 14915 + timestamp: 1705980172730 +- kind: conda + name: libbrotlicommon + version: 1.1.0 + build: h0dc2134_1 + build_number: 1 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/libbrotlicommon-1.1.0-h0dc2134_1.conda + sha256: f57c57c442ef371982619f82af8735f93a4f50293022cfd1ffaf2ff89c2e0b2a + md5: 9e6c31441c9aa24e41ace40d6151aab6 + license: MIT + license_family: MIT + size: 67476 + timestamp: 1695990207321 +- kind: conda + name: libbrotlicommon + version: 1.1.0 + build: hb547adb_1 + build_number: 1 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/libbrotlicommon-1.1.0-hb547adb_1.conda + sha256: 556f0fddf4bd4d35febab404d98cb6862ce3b7ca843e393da0451bfc4654cf07 + md5: cd68f024df0304be41d29a9088162b02 + license: MIT + license_family: MIT + size: 68579 + timestamp: 1695990426128 +- kind: conda + name: libbrotlicommon + version: 1.1.0 + build: hd590300_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libbrotlicommon-1.1.0-hd590300_1.conda + sha256: 40f29d1fab92c847b083739af86ad2f36d8154008cf99b64194e4705a1725d78 + md5: aec6c91c7371c26392a06708a73c70e5 + depends: + - libgcc-ng >=12 + license: MIT + license_family: MIT + size: 69403 + timestamp: 1695990007212 +- kind: conda + name: libbrotlidec + version: 1.1.0 + build: h0dc2134_1 + build_number: 1 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/libbrotlidec-1.1.0-h0dc2134_1.conda + sha256: b11939c4c93c29448660ab5f63273216969d1f2f315dd9be60f3c43c4e61a50c + md5: 9ee0bab91b2ca579e10353738be36063 + depends: + - libbrotlicommon 1.1.0 h0dc2134_1 + license: MIT + license_family: MIT + size: 30327 + timestamp: 1695990232422 +- kind: conda + name: libbrotlidec + version: 1.1.0 + build: hb547adb_1 + build_number: 1 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/libbrotlidec-1.1.0-hb547adb_1.conda + sha256: c1c85937828ad3bc434ac60b7bcbde376f4d2ea4ee42d15d369bf2a591775b4a + md5: ee1a519335cc10d0ec7e097602058c0a + depends: + - libbrotlicommon 1.1.0 hb547adb_1 + license: MIT + license_family: MIT + size: 28928 + timestamp: 1695990463780 +- kind: conda + name: libbrotlidec + version: 1.1.0 + build: hd590300_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libbrotlidec-1.1.0-hd590300_1.conda + sha256: 86fc861246fbe5ad85c1b6b3882aaffc89590a48b42d794d3d5c8e6d99e5f926 + md5: f07002e225d7a60a694d42a7bf5ff53f + depends: + - libbrotlicommon 1.1.0 hd590300_1 + - libgcc-ng >=12 + license: MIT + license_family: MIT + size: 32775 + timestamp: 1695990022788 +- kind: conda + name: libbrotlienc + version: 1.1.0 + build: h0dc2134_1 + build_number: 1 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/libbrotlienc-1.1.0-h0dc2134_1.conda + sha256: bc964c23e1a60ca1afe7bac38a9c1f2af3db4a8072c9f2eac4e4de537a844ac7 + md5: 8a421fe09c6187f0eb5e2338a8a8be6d + depends: + - libbrotlicommon 1.1.0 h0dc2134_1 + license: MIT + license_family: MIT + size: 299092 + timestamp: 1695990259225 +- kind: conda + name: libbrotlienc + version: 1.1.0 + build: hb547adb_1 + build_number: 1 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/libbrotlienc-1.1.0-hb547adb_1.conda + sha256: 690dfc98e891ee1871c54166d30f6e22edfc2d7d6b29e7988dde5f1ce271c81a + md5: d7e077f326a98b2cc60087eaff7c730b + depends: + - libbrotlicommon 1.1.0 hb547adb_1 + license: MIT + license_family: MIT + size: 280943 + timestamp: 1695990509392 +- kind: conda + name: libbrotlienc + version: 1.1.0 + build: hd590300_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libbrotlienc-1.1.0-hd590300_1.conda + sha256: f751b8b1c4754a2a8dfdc3b4040fa7818f35bbf6b10e905a47d3a194b746b071 + md5: 5fc11c6020d421960607d821310fcd4d + depends: + - libbrotlicommon 1.1.0 hd590300_1 + - libgcc-ng >=12 + license: MIT + license_family: MIT + size: 282523 + timestamp: 1695990038302 +- kind: conda + name: libcap + version: '2.69' + build: h0f662aa_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libcap-2.69-h0f662aa_0.conda + sha256: 942f9564b4228609f017b6617425d29a74c43b8a030e12239fa4458e5cb6323c + md5: 25cb5999faa414e5ccb2c1388f62d3d5 + depends: + - attr >=2.5.1,<2.6.0a0 + - libgcc-ng >=12 + license: BSD-3-Clause + license_family: BSD + size: 100582 + timestamp: 1684162447012 - kind: conda name: libcblas version: 3.9.0 @@ -1075,6 +2223,59 @@ packages: license_family: BSD size: 14800 timestamp: 1705980195551 +- kind: conda + name: libclang + version: 15.0.7 + build: default_hb11cfb5_4 + build_number: 4 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_hb11cfb5_4.conda + sha256: 0b80441f222a91074d0e5edb0fbc3b1ce16ca2cdf6ab899721afdcc3a3ff6302 + md5: c90f4cbb57839c98fef8f830e4b9972f + depends: + - libclang13 15.0.7 default_ha2b6cf4_4 + - libgcc-ng >=12 + - libllvm15 >=15.0.7,<15.1.0a0 + - libstdcxx-ng >=12 + license: Apache-2.0 WITH LLVM-exception + license_family: Apache + size: 133384 + timestamp: 1701412265788 +- kind: conda + name: libclang13 + version: 15.0.7 + build: default_ha2b6cf4_4 + build_number: 4 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libclang13-15.0.7-default_ha2b6cf4_4.conda + sha256: e1d34d415160b69a401dc0662bf1b5378655193ed1364bf7dd14f055e76e4b60 + md5: 898e0dd993afbed0d871b60c2eb33b83 + depends: + - libgcc-ng >=12 + - libllvm15 >=15.0.7,<15.1.0a0 + - libstdcxx-ng >=12 + license: Apache-2.0 WITH LLVM-exception + license_family: Apache + size: 9581845 + timestamp: 1701412208888 +- kind: conda + name: libcups + version: 2.3.3 + build: h4637d8d_4 + build_number: 4 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h4637d8d_4.conda + sha256: bc67b9b21078c99c6bd8595fe7e1ed6da1f721007726e717f0449de7032798c4 + md5: d4529f4dff3057982a7617c7ac58fde3 + depends: + - krb5 >=1.21.1,<1.22.0a0 + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - libzlib >=1.2.13,<1.3.0a0 + license: Apache-2.0 + license_family: Apache + size: 4519402 + timestamp: 1689195353551 - kind: conda name: libcxx version: 16.0.6 @@ -1099,6 +2300,93 @@ packages: license_family: Apache size: 1142172 timestamp: 1686896907750 +- kind: conda + name: libdeflate + version: '1.19' + build: ha4e1b8e_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/libdeflate-1.19-ha4e1b8e_0.conda + sha256: d0f789120fedd0881b129aba9993ec5dcf0ecca67a71ea20c74394e41adcb503 + md5: 6a45f543c2beb40023df5ee7e3cedfbd + license: MIT + license_family: MIT + size: 68962 + timestamp: 1694922440450 +- kind: conda + name: libdeflate + version: '1.19' + build: hb547adb_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/libdeflate-1.19-hb547adb_0.conda + sha256: 6a3d188a6ae845a742dc85c5fb3f7eb1e252726cd74f0b8a7fa25ec09db6b87a + md5: f8c1eb0e99e90b55965c6558578537cc + license: MIT + license_family: MIT + size: 52841 + timestamp: 1694924330786 +- kind: conda + name: libdeflate + version: '1.19' + build: hd590300_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.19-hd590300_0.conda + sha256: 985ad27aa0ba7aad82afa88a8ede6a1aacb0aaca950d710f15d85360451e72fd + md5: 1635570038840ee3f9c71d22aa5b8b6d + depends: + - libgcc-ng >=12 + license: MIT + license_family: MIT + size: 67080 + timestamp: 1694922285678 +- kind: conda + name: libedit + version: 3.1.20191231 + build: he28a2e2_2 + build_number: 2 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20191231-he28a2e2_2.tar.bz2 + sha256: a57d37c236d8f7c886e01656f4949d9dcca131d2a0728609c6f7fa338b65f1cf + md5: 4d331e44109e3f0e19b4cb8f9b82f3e1 + depends: + - libgcc-ng >=7.5.0 + - ncurses >=6.2,<7.0.0a0 + license: BSD-2-Clause + license_family: BSD + size: 123878 + timestamp: 1597616541093 +- kind: conda + name: libevent + version: 2.1.12 + build: hf998b51_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libevent-2.1.12-hf998b51_1.conda + sha256: 2e14399d81fb348e9d231a82ca4d816bf855206923759b69ad006ba482764131 + md5: a1cfcc585f0c42bf8d5546bb1dfb668d + depends: + - libgcc-ng >=12 + - openssl >=3.1.1,<4.0a0 + license: BSD-3-Clause + license_family: BSD + size: 427426 + timestamp: 1685725977222 +- kind: conda + name: libexpat + version: 2.5.0 + build: hcb278e6_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.5.0-hcb278e6_1.conda + sha256: 74c98a563777ae2ad71f1f74d458a8ab043cee4a513467c159ccf159d0e461f3 + md5: 6305a3dd2752c76335295da4e581f2fd + depends: + - libgcc-ng >=12 + constrains: + - expat 2.5.0.* + license: MIT + license_family: MIT + size: 77980 + timestamp: 1680190528313 - kind: conda name: libffi version: 3.4.2 @@ -1140,6 +2428,24 @@ packages: license_family: MIT size: 58292 timestamp: 1636488182923 +- kind: conda + name: libflac + version: 1.4.3 + build: h59595ed_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libflac-1.4.3-h59595ed_0.conda + sha256: 65908b75fa7003167b8a8f0001e11e58ed5b1ef5e98b96ab2ba66d7c1b822c7d + md5: ee48bf17cc83a00f59ca1494d5646869 + depends: + - gettext >=0.21.1,<1.0a0 + - libgcc-ng >=12 + - libogg 1.3.* + - libogg >=1.3.4,<1.4.0a0 + - libstdcxx-ng >=12 + license: BSD-3-Clause + license_family: BSD + size: 394383 + timestamp: 1687765514062 - kind: conda name: libgcc-ng version: 13.2.0 @@ -1158,6 +2464,21 @@ packages: license_family: GPL size: 770506 timestamp: 1706819192021 +- kind: conda + name: libgcrypt + version: 1.10.3 + build: hd590300_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.3-hd590300_0.conda + sha256: d1bd47faa29fec7288c7b212198432b07f890d3d6f646078da93b059c2e9daff + md5: 32d16ad533c59bb0a3c5ffaf16110829 + depends: + - libgcc-ng >=12 + - libgpg-error >=1.47,<2.0a0 + license: LGPL-2.1-or-later AND GPL-2.0-or-later + license_family: GPL + size: 634887 + timestamp: 1701383493365 - kind: conda name: libgfortran version: 5.0.0 @@ -1254,6 +2575,27 @@ packages: license_family: GPL size: 997381 timestamp: 1707330687590 +- kind: conda + name: libglib + version: 2.78.4 + build: h783c2da_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.4-h783c2da_0.conda + sha256: 3a03a5254d2fd29c1e0ffda7250e22991dfbf2c854301fd56c408d97a647cfbd + md5: d86baf8740d1a906b9716f2a0bac2f2d + depends: + - gettext >=0.21.1,<1.0a0 + - libffi >=3.4,<4.0a0 + - libgcc-ng >=12 + - libiconv >=1.17,<2.0a0 + - libstdcxx-ng >=12 + - libzlib >=1.2.13,<1.3.0a0 + - pcre2 >=10.42,<10.43.0a0 + constrains: + - glib 2.78.4 *_0 + license: LGPL-2.1-or-later + size: 2692079 + timestamp: 1708284870228 - kind: conda name: libgomp version: 13.2.0 @@ -1269,6 +2611,22 @@ packages: license_family: GPL size: 419751 timestamp: 1706819107383 +- kind: conda + name: libgpg-error + version: '1.48' + build: h71f35ed_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libgpg-error-1.48-h71f35ed_0.conda + sha256: c448c6d86d27e10b9e844172000540e9cbfe9c28f968db87f949ba05add9bd50 + md5: 4d18d86916705d352d5f4adfb7f0edd3 + depends: + - gettext >=0.21.1,<1.0a0 + - libgcc-ng >=12 + - libstdcxx-ng >=12 + license: GPL-2.0-only + license_family: GPL + size: 266447 + timestamp: 1708702470365 - kind: conda name: libgrpc version: 1.59.3 @@ -1344,6 +2702,64 @@ packages: license_family: APACHE size: 6600132 timestamp: 1700259627150 +- kind: conda + name: libiconv + version: '1.17' + build: hd590300_2 + build_number: 2 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libiconv-1.17-hd590300_2.conda + sha256: 8ac2f6a9f186e76539439e50505d98581472fedb347a20e7d1f36429849f05c9 + md5: d66573916ffcf376178462f1b61c941e + depends: + - libgcc-ng >=12 + license: LGPL-2.1-only + size: 705775 + timestamp: 1702682170569 +- kind: conda + name: libjpeg-turbo + version: 3.0.0 + build: h0dc2134_1 + build_number: 1 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/libjpeg-turbo-3.0.0-h0dc2134_1.conda + sha256: d9572fd1024adc374aae7c247d0f29fdf4b122f1e3586fe62acc18067f40d02f + md5: 72507f8e3961bc968af17435060b6dd6 + constrains: + - jpeg <0.0.0a + license: IJG AND BSD-3-Clause AND Zlib + size: 579748 + timestamp: 1694475265912 +- kind: conda + name: libjpeg-turbo + version: 3.0.0 + build: hb547adb_1 + build_number: 1 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/libjpeg-turbo-3.0.0-hb547adb_1.conda + sha256: a42054eaa38e84fc1e5ab443facac4bbc9d1b6b6f23f54b7bf4f1eb687e1d993 + md5: 3ff1e053dc3a2b8e36b9bfa4256a58d1 + constrains: + - jpeg <0.0.0a + license: IJG AND BSD-3-Clause AND Zlib + size: 547541 + timestamp: 1694475104253 +- kind: conda + name: libjpeg-turbo + version: 3.0.0 + build: hd590300_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libjpeg-turbo-3.0.0-hd590300_1.conda + sha256: b954e09b7e49c2f2433d6f3bb73868eda5e378278b0f8c1dd10a7ef090e14f2f + md5: ea25936bb4080d843790b586850f82b8 + depends: + - libgcc-ng >=12 + constrains: + - jpeg <0.0.0a + license: IJG AND BSD-3-Clause AND Zlib + size: 618575 + timestamp: 1694474974816 - kind: conda name: liblapack version: 3.9.0 @@ -1401,6 +2817,25 @@ packages: license_family: BSD size: 14829 timestamp: 1705980215575 +- kind: conda + name: libllvm15 + version: 15.0.7 + build: hb3ce162_4 + build_number: 4 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libllvm15-15.0.7-hb3ce162_4.conda + sha256: e71584c0f910140630580fdd0a013029a52fd31e435192aea2aa8d29005262d1 + md5: 8a35df3cbc0c8b12cc8af9473ae75eef + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - libxml2 >=2.12.1,<3.0.0a0 + - libzlib >=1.2.13,<1.3.0a0 + - zstd >=1.5.5,<1.6.0a0 + license: Apache-2.0 WITH LLVM-exception + license_family: Apache + size: 33321457 + timestamp: 1701375836233 - kind: conda name: libnsl version: 2.0.1 @@ -1415,6 +2850,21 @@ packages: license_family: GPL size: 33408 timestamp: 1697359010159 +- kind: conda + name: libogg + version: 1.3.4 + build: h7f98852_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libogg-1.3.4-h7f98852_1.tar.bz2 + sha256: b88afeb30620b11bed54dac4295aa57252321446ba4e6babd7dce4b9ffde9b25 + md5: 6e8cc2173440d77708196c5b93771680 + depends: + - libgcc-ng >=9.3.0 + license: BSD-3-Clause + license_family: BSD + size: 210550 + timestamp: 1610382007814 - kind: conda name: libopenblas version: 0.3.26 @@ -1469,6 +2919,76 @@ packages: license_family: BSD size: 5578031 timestamp: 1704950143521 +- kind: conda + name: libopus + version: 1.3.1 + build: h7f98852_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libopus-1.3.1-h7f98852_1.tar.bz2 + sha256: 0e1c2740ebd1c93226dc5387461bbcf8142c518f2092f3ea7551f77755decc8f + md5: 15345e56d527b330e1cacbdf58676e8f + depends: + - libgcc-ng >=9.3.0 + license: BSD-3-Clause + license_family: BSD + size: 260658 + timestamp: 1606823578035 +- kind: conda + name: libpng + version: 1.6.43 + build: h091b4b1_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/libpng-1.6.43-h091b4b1_0.conda + sha256: 66c4713b07408398f2221229a1c1d5df57d65dc0902258113f2d9ecac4772495 + md5: 77e684ca58d82cae9deebafb95b1a2b8 + depends: + - libzlib >=1.2.13,<1.3.0a0 + license: zlib-acknowledgement + size: 264177 + timestamp: 1708780447187 +- kind: conda + name: libpng + version: 1.6.43 + build: h2797004_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.43-h2797004_0.conda + sha256: 502f6ff148ac2777cc55ae4ade01a8fc3543b4ffab25c4e0eaa15f94e90dd997 + md5: 009981dd9cfcaa4dbfa25ffaed86bcae + depends: + - libgcc-ng >=12 + - libzlib >=1.2.13,<1.3.0a0 + license: zlib-acknowledgement + size: 288221 + timestamp: 1708780443939 +- kind: conda + name: libpng + version: 1.6.43 + build: h92b6c6a_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/libpng-1.6.43-h92b6c6a_0.conda + sha256: 13e646d24b5179e6b0a5ece4451a587d759f55d9a360b7015f8f96eff4524b8f + md5: 65dcddb15965c9de2c0365cb14910532 + depends: + - libzlib >=1.2.13,<1.3.0a0 + license: zlib-acknowledgement + size: 268524 + timestamp: 1708780496420 +- kind: conda + name: libpq + version: '16.2' + build: h33b98f1_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libpq-16.2-h33b98f1_0.conda + sha256: 352748b0499a22e2a8e103f071b8d9357e1fb710c0aec0f79895d3ba03dccb03 + md5: fe0e297faf462ee579c95071a5211665 + depends: + - krb5 >=1.21.2,<1.22.0a0 + - libgcc-ng >=12 + - openssl >=3.2.1,<4.0a0 + license: PostgreSQL + size: 2474825 + timestamp: 1707415138154 - kind: conda name: libprotobuf version: 4.24.4 @@ -1580,6 +3100,28 @@ packages: license_family: BSD size: 232708 timestamp: 1697065825934 +- kind: conda + name: libsndfile + version: 1.2.2 + build: hc60ed4a_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libsndfile-1.2.2-hc60ed4a_1.conda + sha256: f709cbede3d4f3aee4e2f8d60bd9e256057f410bd60b8964cb8cf82ec1457573 + md5: ef1910918dd895516a769ed36b5b3a4e + depends: + - lame >=3.100,<3.101.0a0 + - libflac >=1.4.3,<1.5.0a0 + - libgcc-ng >=12 + - libogg >=1.3.4,<1.4.0a0 + - libopus >=1.3.1,<2.0a0 + - libstdcxx-ng >=12 + - libvorbis >=1.3.7,<1.4.0a0 + - mpg123 >=1.32.1,<1.33.0a0 + license: LGPL-2.1-or-later + license_family: LGPL + size: 354372 + timestamp: 1695747735668 - kind: conda name: libsqlite version: 3.45.1 @@ -1608,45 +3150,276 @@ packages: size: 859346 timestamp: 1707495156652 - kind: conda - name: libsqlite - version: 3.45.1 - build: h92b6c6a_0 + name: libsqlite + version: 3.45.1 + build: h92b6c6a_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/libsqlite-3.45.1-h92b6c6a_0.conda + sha256: d65ce7093ecf5884b241a5ca8d26f80d21eaebf14ca67923b50c249f47a84cf9 + md5: e451d14a5412cdc68be50493df251f55 + depends: + - libzlib >=1.2.13,<1.3.0a0 + license: Unlicense + size: 902313 + timestamp: 1707495366004 +- kind: conda + name: libstdcxx-ng + version: 13.2.0 + build: h7e041cc_5 + build_number: 5 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_5.conda + sha256: a56c5b11f1e73a86e120e6141a42d9e935a99a2098491ac9e15347a1476ce777 + md5: f6f6600d18a4047b54f803cf708b868a + license: GPL-3.0-only WITH GCC-exception-3.1 + license_family: GPL + size: 3834139 + timestamp: 1706819252496 +- kind: conda + name: libsystemd0 + version: '255' + build: h3516f8a_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libsystemd0-255-h3516f8a_1.conda + sha256: af27b0d225435d03f378a119f8eab6b280c53557a3c84cdb3bb8fd3167615aed + md5: 3366af27f0b593544a6cd453c7932ac5 + depends: + - __glibc >=2.17,<3.0.a0 + - libcap >=2.69,<2.70.0a0 + - libgcc-ng >=12 + - libgcrypt >=1.10.3,<2.0a0 + - lz4-c >=1.9.3,<1.10.0a0 + - xz >=5.2.6,<6.0a0 + - zstd >=1.5.5,<1.6.0a0 + license: LGPL-2.1-or-later + size: 402592 + timestamp: 1709568499820 +- kind: conda + name: libtiff + version: 4.6.0 + build: h684deea_2 + build_number: 2 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/libtiff-4.6.0-h684deea_2.conda + sha256: 1ef5bd7295f4316b111f70ad21356fb9f0de50b85a341cac9e3a61ac6487fdf1 + md5: 2ca10a325063e000ad6d2a5900061e0d + depends: + - lerc >=4.0.0,<5.0a0 + - libcxx >=15.0.7 + - libdeflate >=1.19,<1.20.0a0 + - libjpeg-turbo >=3.0.0,<4.0a0 + - libwebp-base >=1.3.2,<2.0a0 + - libzlib >=1.2.13,<1.3.0a0 + - xz >=5.2.6,<6.0a0 + - zstd >=1.5.5,<1.6.0a0 + license: HPND + size: 266501 + timestamp: 1695661828714 +- kind: conda + name: libtiff + version: 4.6.0 + build: ha8a6c65_2 + build_number: 2 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/libtiff-4.6.0-ha8a6c65_2.conda + sha256: b18ef36eb90f190db22c56ae5a080bccc16669c8f5b795a6211d7b0c00c18ff7 + md5: 596d6d949bab9a75a492d451f521f457 + depends: + - lerc >=4.0.0,<5.0a0 + - libcxx >=15.0.7 + - libdeflate >=1.19,<1.20.0a0 + - libjpeg-turbo >=3.0.0,<4.0a0 + - libwebp-base >=1.3.2,<2.0a0 + - libzlib >=1.2.13,<1.3.0a0 + - xz >=5.2.6,<6.0a0 + - zstd >=1.5.5,<1.6.0a0 + license: HPND + size: 246265 + timestamp: 1695661829324 +- kind: conda + name: libtiff + version: 4.6.0 + build: ha9c0a0a_2 + build_number: 2 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.6.0-ha9c0a0a_2.conda + sha256: 45158f5fbee7ee3e257e6b9f51b9f1c919ed5518a94a9973fe7fa4764330473e + md5: 55ed21669b2015f77c180feb1dd41930 + depends: + - lerc >=4.0.0,<5.0a0 + - libdeflate >=1.19,<1.20.0a0 + - libgcc-ng >=12 + - libjpeg-turbo >=3.0.0,<4.0a0 + - libstdcxx-ng >=12 + - libwebp-base >=1.3.2,<2.0a0 + - libzlib >=1.2.13,<1.3.0a0 + - xz >=5.2.6,<6.0a0 + - zstd >=1.5.5,<1.6.0a0 + license: HPND + size: 283198 + timestamp: 1695661593314 +- kind: conda + name: libuuid + version: 2.38.1 + build: h0b41bf4_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda + sha256: 787eb542f055a2b3de553614b25f09eefb0a0931b0c87dbcce6efdfd92f04f18 + md5: 40b61aab5c7ba9ff276c41cfffe6b80b + depends: + - libgcc-ng >=12 + license: BSD-3-Clause + license_family: BSD + size: 33601 + timestamp: 1680112270483 +- kind: conda + name: libvorbis + version: 1.3.7 + build: h9c3ff4c_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libvorbis-1.3.7-h9c3ff4c_0.tar.bz2 + sha256: 53080d72388a57b3c31ad5805c93a7328e46ff22fab7c44ad2a86d712740af33 + md5: 309dec04b70a3cc0f1e84a4013683bc0 + depends: + - libgcc-ng >=9.3.0 + - libogg >=1.3.4,<1.4.0a0 + - libstdcxx-ng >=9.3.0 + license: BSD-3-Clause + license_family: BSD + size: 286280 + timestamp: 1610609811627 +- kind: conda + name: libwebp-base + version: 1.3.2 + build: h0dc2134_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/libwebp-base-1.3.2-h0dc2134_0.conda + sha256: fa7580f26fec4c28321ec2ece1257f3293e0c646c635e9904679f4a8369be401 + md5: 4e7e9d244e87d66c18d36894fd6a8ae5 + constrains: + - libwebp 1.3.2 + license: BSD-3-Clause + license_family: BSD + size: 346599 + timestamp: 1694709233836 +- kind: conda + name: libwebp-base + version: 1.3.2 + build: hb547adb_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/libwebp-base-1.3.2-hb547adb_0.conda + sha256: a159b848193043fb58465ae6a449361615dadcf27babfe0b18db2bd3eb59e958 + md5: 85dbc11098cdbe4244cd73f29a3ab795 + constrains: + - libwebp 1.3.2 + license: BSD-3-Clause + license_family: BSD + size: 273844 + timestamp: 1694709510635 +- kind: conda + name: libwebp-base + version: 1.3.2 + build: hd590300_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libwebp-base-1.3.2-hd590300_0.conda + sha256: 68764a760fa81ef35dacb067fe8ace452bbb41476536a4a147a1051df29525f0 + md5: 30de3fd9b3b602f7473f30e684eeea8c + depends: + - libgcc-ng >=12 + constrains: + - libwebp 1.3.2 + license: BSD-3-Clause + license_family: BSD + size: 401830 + timestamp: 1694709121323 +- kind: conda + name: libxcb + version: '1.15' + build: h0b41bf4_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.15-h0b41bf4_0.conda + sha256: a670902f0a3173a466c058d2ac22ca1dd0df0453d3a80e0212815c20a16b0485 + md5: 33277193f5b92bad9fdd230eb700929c + depends: + - libgcc-ng >=12 + - pthread-stubs + - xorg-libxau + - xorg-libxdmcp + license: MIT + license_family: MIT + size: 384238 + timestamp: 1682082368177 +- kind: conda + name: libxcb + version: '1.15' + build: hb7f2c08_0 subdir: osx-64 - url: https://conda.anaconda.org/conda-forge/osx-64/libsqlite-3.45.1-h92b6c6a_0.conda - sha256: d65ce7093ecf5884b241a5ca8d26f80d21eaebf14ca67923b50c249f47a84cf9 - md5: e451d14a5412cdc68be50493df251f55 + url: https://conda.anaconda.org/conda-forge/osx-64/libxcb-1.15-hb7f2c08_0.conda + sha256: f41904f466acc8b3197f37f2dd3a08da75720c7f7464d9267635debc4ac1902b + md5: 5513f57e0238c87c12dffedbcc9c1a4a depends: - - libzlib >=1.2.13,<1.3.0a0 - license: Unlicense - size: 902313 - timestamp: 1707495366004 + - pthread-stubs + - xorg-libxau + - xorg-libxdmcp + license: MIT + license_family: MIT + size: 313793 + timestamp: 1682083036825 - kind: conda - name: libstdcxx-ng - version: 13.2.0 - build: h7e041cc_5 - build_number: 5 + name: libxcb + version: '1.15' + build: hf346824_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/libxcb-1.15-hf346824_0.conda + sha256: 6eaa87760ff3e91bb5524189700139db46f8946ff6331f4e571e4a9356edbb0d + md5: 988d5f86ab60fa6de91b3ee3a88a3af9 + depends: + - pthread-stubs + - xorg-libxau + - xorg-libxdmcp + license: MIT + license_family: MIT + size: 334770 + timestamp: 1682082734262 +- kind: conda + name: libxkbcommon + version: 1.6.0 + build: hd429924_1 + build_number: 1 subdir: linux-64 - url: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-13.2.0-h7e041cc_5.conda - sha256: a56c5b11f1e73a86e120e6141a42d9e935a99a2098491ac9e15347a1476ce777 - md5: f6f6600d18a4047b54f803cf708b868a - license: GPL-3.0-only WITH GCC-exception-3.1 - license_family: GPL - size: 3834139 - timestamp: 1706819252496 + url: https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.6.0-hd429924_1.conda + sha256: 213a4c927618198fd5fb5e7b0a76b89310a9c04a3ea025d59771754ee8a89451 + md5: 1dbcc04604fdf1e526e6d1b0b6938396 + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - libxcb >=1.15,<1.16.0a0 + - libxml2 >=2.12.1,<3.0.0a0 + - xkeyboard-config + - xorg-libxau >=1.0.11,<2.0a0 + license: MIT/X11 Derivative + license_family: MIT + size: 574868 + timestamp: 1701352639132 - kind: conda - name: libuuid - version: 2.38.1 - build: h0b41bf4_0 + name: libxml2 + version: 2.12.5 + build: h232c23b_0 subdir: linux-64 - url: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda - sha256: 787eb542f055a2b3de553614b25f09eefb0a0931b0c87dbcce6efdfd92f04f18 - md5: 40b61aab5c7ba9ff276c41cfffe6b80b + url: https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.12.5-h232c23b_0.conda + sha256: db9bf97e9e367985204331b58a059ebd5a4e0cb9e1c8754e9ecb23046b7b7bc1 + md5: c442ebfda7a475f5e78f1c8e45f1e919 depends: + - icu >=73.2,<74.0a0 - libgcc-ng >=12 - license: BSD-3-Clause - license_family: BSD - size: 33601 - timestamp: 1680112270483 + - libiconv >=1.17,<2.0a0 + - libzlib >=1.2.13,<1.3.0a0 + - xz >=5.2.6,<6.0a0 + license: MIT + license_family: MIT + size: 704829 + timestamp: 1707084502281 - kind: conda name: libzlib version: 1.2.13 @@ -1722,6 +3495,21 @@ packages: license_family: APACHE size: 274631 timestamp: 1701222947083 +- kind: conda + name: lz4-c + version: 1.9.4 + build: hcb278e6_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/lz4-c-1.9.4-hcb278e6_0.conda + sha256: 1b4c105a887f9b2041219d57036f72c4739ab9e9fe5a1486f094e58c76b31f5f + md5: 318b08df404f9c9be5712aaa5a6f0bb0 + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + license: BSD-2-Clause + license_family: BSD + size: 143402 + timestamp: 1674727076728 - kind: conda name: markdown-it-py version: 3.0.0 @@ -1791,6 +3579,146 @@ packages: license_family: BSD size: 26155 timestamp: 1706900211496 +- kind: conda + name: matplotlib + version: 3.8.3 + build: py311h38be061_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/matplotlib-3.8.3-py311h38be061_0.conda + sha256: e3c4aed587c91fdd1ecc2a8ba50a774e1edc7ed4dd4451fcd59bf74f07b58b97 + md5: 0452c2cca94bdda38a16cf7b84edcd27 + depends: + - matplotlib-base >=3.8.3,<3.8.4.0a0 + - pyqt >=5.10 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + - tornado >=5 + license: PSF-2.0 + license_family: PSF + size: 8535 + timestamp: 1708026784226 +- kind: conda + name: matplotlib + version: 3.8.3 + build: py311h6eed73b_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/matplotlib-3.8.3-py311h6eed73b_0.conda + sha256: 029214f70506c5acd18377c74644b921a34d2b454bbd976787c46e668b11931c + md5: 30bdee405877d3291c38ffa5819e3166 + depends: + - matplotlib-base >=3.8.3,<3.8.4.0a0 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + - tornado >=5 + license: PSF-2.0 + license_family: PSF + size: 8613 + timestamp: 1708027016401 +- kind: conda + name: matplotlib + version: 3.8.3 + build: py311ha1ab1f8_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/matplotlib-3.8.3-py311ha1ab1f8_0.conda + sha256: 45183aaecb83400b0e435ae9c309844c50a83836a27946fd8c618888c79ae624 + md5: 2aea37eb7fb61fdf23356864d79a8720 + depends: + - matplotlib-base >=3.8.3,<3.8.4.0a0 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + - tornado >=5 + license: PSF-2.0 + license_family: PSF + size: 8665 + timestamp: 1708027249910 +- kind: conda + name: matplotlib-base + version: 3.8.3 + build: py311h54ef318_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.8.3-py311h54ef318_0.conda + sha256: 3b1d85d61b2c88e72449c1fb2fb0893522512d0924a50aca608ba58663253907 + md5: 014c115be880802d2372ac6ed665f526 + depends: + - certifi >=2020.06.20 + - contourpy >=1.0.1 + - cycler >=0.10 + - fonttools >=4.22.0 + - freetype >=2.12.1,<3.0a0 + - kiwisolver >=1.3.1 + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - numpy >=1.21,<2 + - numpy >=1.23.5,<2.0a0 + - packaging >=20.0 + - pillow >=8 + - pyparsing >=2.3.1 + - python >=3.11,<3.12.0a0 + - python-dateutil >=2.7 + - python_abi 3.11.* *_cp311 + - tk >=8.6.13,<8.7.0a0 + license: PSF-2.0 + license_family: PSF + size: 7927557 + timestamp: 1708026755428 +- kind: conda + name: matplotlib-base + version: 3.8.3 + build: py311h6ff1f5f_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/matplotlib-base-3.8.3-py311h6ff1f5f_0.conda + sha256: 8b317ebb64621325aa56630989a500c67dedc7512eec892de85fe9c676eadf9a + md5: 34a8ced9af5c6c771d5c18213151a639 + depends: + - __osx >=10.12 + - certifi >=2020.06.20 + - contourpy >=1.0.1 + - cycler >=0.10 + - fonttools >=4.22.0 + - freetype >=2.12.1,<3.0a0 + - kiwisolver >=1.3.1 + - libcxx >=16 + - numpy >=1.21,<2 + - numpy >=1.23.5,<2.0a0 + - packaging >=20.0 + - pillow >=8 + - pyparsing >=2.3.1 + - python >=3.11,<3.12.0a0 + - python-dateutil >=2.7 + - python_abi 3.11.* *_cp311 + license: PSF-2.0 + license_family: PSF + size: 7806156 + timestamp: 1708026973946 +- kind: conda + name: matplotlib-base + version: 3.8.3 + build: py311hb58f1d1_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/matplotlib-base-3.8.3-py311hb58f1d1_0.conda + sha256: 03a67cafc8a54b0b78170e6a770981aa7d0e657478a4ff394afd78cafe2a197f + md5: 2a92e691e859ebdd98d60a9664d42074 + depends: + - certifi >=2020.06.20 + - contourpy >=1.0.1 + - cycler >=0.10 + - fonttools >=4.22.0 + - freetype >=2.12.1,<3.0a0 + - kiwisolver >=1.3.1 + - libcxx >=16 + - numpy >=1.21,<2 + - numpy >=1.23.5,<2.0a0 + - packaging >=20.0 + - pillow >=8 + - pyparsing >=2.3.1 + - python >=3.11,<3.12.0a0 + - python >=3.11,<3.12.0a0 *_cpython + - python-dateutil >=2.7 + - python_abi 3.11.* *_cp311 + license: PSF-2.0 + license_family: PSF + size: 7773555 + timestamp: 1708027200046 - kind: conda name: matplotlib-inline version: 0.1.6 @@ -1888,6 +3816,103 @@ packages: license: MPL-2.0 AND Apache-2.0 size: 126895 timestamp: 1704728281985 +- kind: conda + name: mpg123 + version: 1.32.4 + build: h59595ed_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/mpg123-1.32.4-h59595ed_0.conda + sha256: 512f4ad7eda3b2c9a1cc9f7931932aefa6e79567e35b76de03895e769cb3b43c + md5: 3f1017b4141e943d9bc8739237f749e8 + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + license: LGPL-2.1-only + license_family: LGPL + size: 491061 + timestamp: 1704980200966 +- kind: conda + name: mplhep + version: 0.3.35 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/mplhep-0.3.35-pyhd8ed1ab_0.conda + sha256: 66adbe0d7e3e86011d60dbae2e5319d20d7c1985684c3e8e982df98cb2a9ca9e + md5: ff9d3b72d14eb3390fd650435cae8c89 + depends: + - matplotlib-base >=3.4 + - mplhep_data + - numpy >=1.16.0 + - packaging + - python >=3.7 + - uhi >=0.2.0 + license: MIT + license_family: MIT + size: 36529 + timestamp: 1708713850581 +- kind: conda + name: mplhep_data + version: 0.0.3 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/mplhep_data-0.0.3-pyhd8ed1ab_0.tar.bz2 + sha256: 1c59385af1d2e39bc62f4543ebbf175f35384919846ead6ef525e2d7b59e284c + md5: 33d02b47a4a63aae3e1340fba09a8bb5 + depends: + - python >=3.7 + license: MIT AND OFL-1.1 AND LPPL-1.3c + size: 4988336 + timestamp: 1629888799240 +- kind: conda + name: munkres + version: 1.1.4 + build: pyh9f0ad1d_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2 + sha256: f86fb22b58e93d04b6f25e0d811b56797689d598788b59dcb47f59045b568306 + md5: 2ba8498c1018c1e9c61eb99b973dfe19 + depends: + - python + license: Apache-2.0 + license_family: Apache + size: 12452 + timestamp: 1600387789153 +- kind: conda + name: mysql-common + version: 8.0.33 + build: hf1915f5_6 + build_number: 6 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/mysql-common-8.0.33-hf1915f5_6.conda + sha256: c8b2c5c9d0d013a4f6ef96cb4b339bfdc53a74232d8c61ed08178e5b1ec4eb63 + md5: 80bf3b277c120dd294b51d404b931a75 + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - openssl >=3.1.4,<4.0a0 + size: 753467 + timestamp: 1698937026421 +- kind: conda + name: mysql-libs + version: 8.0.33 + build: hca2cd23_6 + build_number: 6 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-8.0.33-hca2cd23_6.conda + sha256: 78c905637dac79b197395065c169d452b8ca2a39773b58e45e23114f1cb6dcdb + md5: e87530d1b12dd7f4e0f856dc07358d60 + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - libzlib >=1.2.13,<1.3.0a0 + - mysql-common 8.0.33 hf1915f5_6 + - openssl >=3.1.4,<4.0a0 + - zstd >=1.5.5,<1.6.0a0 + size: 1530126 + timestamp: 1698937116126 - kind: conda name: myst-parser version: 2.0.0 @@ -1951,6 +3976,40 @@ packages: license: X11 AND BSD-3-Clause size: 822031 timestamp: 1698751567986 +- kind: conda + name: nspr + version: '4.35' + build: h27087fc_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/nspr-4.35-h27087fc_0.conda + sha256: 8fadeebb2b7369a4f3b2c039a980d419f65c7b18267ba0c62588f9f894396d0c + md5: da0ec11a6454ae19bff5b02ed881a2b1 + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + license: MPL-2.0 + license_family: MOZILLA + size: 226848 + timestamp: 1669784948267 +- kind: conda + name: nss + version: '3.98' + build: h1d7d5a4_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/nss-3.98-h1d7d5a4_0.conda + sha256: a9bc94d03df48014011cf6caaf447f2ef86a5edf7c70d70002ec4b59f5a4e198 + md5: 54b56c2fdf973656b748e0378900ec13 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc-ng >=12 + - libsqlite >=3.45.1,<4.0a0 + - libstdcxx-ng >=12 + - libzlib >=1.2.13,<1.3.0a0 + - nspr >=4.35,<5.0a0 + license: MPL-2.0 + license_family: MOZILLA + size: 2019716 + timestamp: 1708065114928 - kind: conda name: numpy version: 1.26.4 @@ -2016,6 +4075,58 @@ packages: license_family: BSD size: 7504319 timestamp: 1707226235372 +- kind: conda + name: openjpeg + version: 2.5.2 + build: h488ebb8_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.2-h488ebb8_0.conda + sha256: 5600a0b82df042bd27d01e4e687187411561dfc11cc05143a08ce29b64bf2af2 + md5: 7f2e286780f072ed750df46dc2631138 + depends: + - libgcc-ng >=12 + - libpng >=1.6.43,<1.7.0a0 + - libstdcxx-ng >=12 + - libtiff >=4.6.0,<4.7.0a0 + - libzlib >=1.2.13,<1.3.0a0 + license: BSD-2-Clause + license_family: BSD + size: 341592 + timestamp: 1709159244431 +- kind: conda + name: openjpeg + version: 2.5.2 + build: h7310d3a_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/openjpeg-2.5.2-h7310d3a_0.conda + sha256: dc9c405119b9b54f8ca5984da27ba498bd848ab4f0f580da6f293009ca5adc13 + md5: 05a14cc9d725dd74995927968d6547e3 + depends: + - libcxx >=16 + - libpng >=1.6.43,<1.7.0a0 + - libtiff >=4.6.0,<4.7.0a0 + - libzlib >=1.2.13,<1.3.0a0 + license: BSD-2-Clause + license_family: BSD + size: 331273 + timestamp: 1709159538792 +- kind: conda + name: openjpeg + version: 2.5.2 + build: h9f1df11_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/openjpeg-2.5.2-h9f1df11_0.conda + sha256: 472d6eaffc1996e6af35ec8e91c967f472a536a470079bfa56383cc0dbf4d463 + md5: 5029846003f0bc14414b9128a1f7c84b + depends: + - libcxx >=16 + - libpng >=1.6.43,<1.7.0a0 + - libtiff >=4.6.0,<4.7.0a0 + - libzlib >=1.2.13,<1.3.0a0 + license: BSD-2-Clause + license_family: BSD + size: 316603 + timestamp: 1709159627299 - kind: conda name: openssl version: 3.2.1 @@ -2128,6 +4239,22 @@ packages: license_family: MIT size: 71048 timestamp: 1638335054552 +- kind: conda + name: pcre2 + version: '10.42' + build: hcad00b1_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.42-hcad00b1_0.conda + sha256: 3ca54ff0abcda964af7d4724d389ae20d931159ae1881cfe57ad4b0ab9e6a380 + md5: 679c8961826aa4b50653bce17ee52abe + depends: + - bzip2 >=1.0.8,<2.0a0 + - libgcc-ng >=12 + - libzlib >=1.2.13,<1.3.0a0 + license: BSD-3-Clause + license_family: BSD + size: 1017235 + timestamp: 1698610864983 - kind: conda name: pexpect version: 4.9.0 @@ -2159,6 +4286,77 @@ packages: license_family: MIT size: 9332 timestamp: 1602536313357 +- kind: conda + name: pillow + version: 10.2.0 + build: py311ha6c5da5_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/pillow-10.2.0-py311ha6c5da5_0.conda + sha256: 3cd4827d822c9888b672bfac9017e905348ac5bd2237a98b30a734ed6573b248 + md5: a5ccd7f2271f28b7d2de0b02b64e3796 + depends: + - freetype >=2.12.1,<3.0a0 + - lcms2 >=2.16,<3.0a0 + - libgcc-ng >=12 + - libjpeg-turbo >=3.0.0,<4.0a0 + - libtiff >=4.6.0,<4.7.0a0 + - libwebp-base >=1.3.2,<2.0a0 + - libxcb >=1.15,<1.16.0a0 + - libzlib >=1.2.13,<1.3.0a0 + - openjpeg >=2.5.0,<3.0a0 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + - tk >=8.6.13,<8.7.0a0 + license: HPND + size: 41629216 + timestamp: 1704252244851 +- kind: conda + name: pillow + version: 10.2.0 + build: py311hb9c5795_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/pillow-10.2.0-py311hb9c5795_0.conda + sha256: c09ed761df062c62e83b78c66a1987a6a727fa45dd5fadde3b436ad5566c216e + md5: 97c499f0ac4792fb1e33295c9adfb351 + depends: + - freetype >=2.12.1,<3.0a0 + - lcms2 >=2.16,<3.0a0 + - libjpeg-turbo >=3.0.0,<4.0a0 + - libtiff >=4.6.0,<4.7.0a0 + - libwebp-base >=1.3.2,<2.0a0 + - libxcb >=1.15,<1.16.0a0 + - libzlib >=1.2.13,<1.3.0a0 + - openjpeg >=2.5.0,<3.0a0 + - python >=3.11,<3.12.0a0 + - python >=3.11,<3.12.0a0 *_cpython + - python_abi 3.11.* *_cp311 + - tk >=8.6.13,<8.7.0a0 + license: HPND + size: 41593553 + timestamp: 1704252636313 +- kind: conda + name: pillow + version: 10.2.0 + build: py311hea5c87a_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/pillow-10.2.0-py311hea5c87a_0.conda + sha256: c3f3d2276943d5bf27d184df76dcef15ad120d23f9eea92e05340093acee98fc + md5: 1709b31ce50343c7a7b3940ed30cc429 + depends: + - freetype >=2.12.1,<3.0a0 + - lcms2 >=2.16,<3.0a0 + - libjpeg-turbo >=3.0.0,<4.0a0 + - libtiff >=4.6.0,<4.7.0a0 + - libwebp-base >=1.3.2,<2.0a0 + - libxcb >=1.15,<1.16.0a0 + - libzlib >=1.2.13,<1.3.0a0 + - openjpeg >=2.5.0,<3.0a0 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + - tk >=8.6.13,<8.7.0a0 + license: HPND + size: 42176355 + timestamp: 1704252505386 - kind: conda name: pip version: '24.0' @@ -2176,6 +4374,37 @@ packages: license_family: MIT size: 1398245 timestamp: 1706960660581 +- kind: conda + name: pixman + version: 0.43.2 + build: h59595ed_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/pixman-0.43.2-h59595ed_0.conda + sha256: 366d28e2a0a191d6c535e234741e0cd1d94d713f76073d8af4a5ccb2a266121e + md5: 71004cbf7924e19c02746ccde9fd7123 + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + license: MIT + license_family: MIT + size: 386826 + timestamp: 1706549500138 +- kind: conda + name: ply + version: '3.11' + build: py_1 + build_number: 1 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/ply-3.11-py_1.tar.bz2 + sha256: 2cd6fae8f9cbc806b7f828f006ae4a83c23fac917cacfd73c37ce322d4324e53 + md5: 7205635cd71531943440fbfe3b6b5727 + depends: + - python + license: BSD 3-clause + license_family: BSD + size: 44837 + timestamp: 1530963184592 - kind: conda name: prompt-toolkit version: 3.0.42 @@ -2194,6 +4423,47 @@ packages: license_family: BSD size: 270398 timestamp: 1702399557137 +- kind: conda + name: pthread-stubs + version: '0.4' + build: h27ca646_1001 + build_number: 1001 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-h27ca646_1001.tar.bz2 + sha256: 9da9e6f5d51dff6ad2e4ee0874791437ba952e0a6249942273f0fedfd07ea826 + md5: d3f26c6494d4105d4ecb85203d687102 + license: MIT + license_family: MIT + size: 5696 + timestamp: 1606147608402 +- kind: conda + name: pthread-stubs + version: '0.4' + build: h36c2ea0_1001 + build_number: 1001 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-h36c2ea0_1001.tar.bz2 + sha256: 67c84822f87b641d89df09758da498b2d4558d47b920fd1d3fe6d3a871e000ff + md5: 22dad4df6e8630e8dff2428f6f6a7036 + depends: + - libgcc-ng >=7.5.0 + license: MIT + license_family: MIT + size: 5625 + timestamp: 1606147468727 +- kind: conda + name: pthread-stubs + version: '0.4' + build: hc929b4f_1001 + build_number: 1001 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/pthread-stubs-0.4-hc929b4f_1001.tar.bz2 + sha256: 6e3900bb241bcdec513d4e7180fe9a19186c1a38f0b4080ed619d26014222c53 + md5: addd19059de62181cd11ae8f4ef26084 + license: MIT + license_family: MIT + size: 5653 + timestamp: 1606147699844 - kind: conda name: ptyprocess version: 0.7.0 @@ -2208,6 +4478,27 @@ packages: license: ISC size: 16546 timestamp: 1609419417991 +- kind: conda + name: pulseaudio-client + version: '16.1' + build: hb77b528_5 + build_number: 5 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-16.1-hb77b528_5.conda + sha256: 9981c70893d95c8cac02e7edd1a9af87f2c8745b772d529f08b7f9dafbe98606 + md5: ac902ff3c1c6d750dd0dfc93a974ab74 + depends: + - dbus >=1.13.6,<2.0a0 + - libgcc-ng >=12 + - libglib >=2.76.4,<3.0a0 + - libsndfile >=1.2.2,<1.3.0a0 + - libsystemd0 >=254 + constrains: + - pulseaudio 16.1 *_5 + license: LGPL-2.1-or-later + license_family: LGPL + size: 754844 + timestamp: 1693928953742 - kind: conda name: pure_eval version: 0.2.2 @@ -2233,11 +4524,67 @@ packages: sha256: af5f8867450dc292f98ea387d4d8945fc574284677c8f60eaa9846ede7387257 md5: 140a7f159396547e9799aa98f9f0742e depends: - - python >=3.7 - license: BSD-2-Clause - license_family: BSD - size: 860425 - timestamp: 1700608076927 + - python >=3.7 + license: BSD-2-Clause + license_family: BSD + size: 860425 + timestamp: 1700608076927 +- kind: conda + name: pyparsing + version: 3.1.2 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.1.2-pyhd8ed1ab_0.conda + sha256: 06c77cb03e5dde2d939b216c99dd2db52ea93a4c7c599f3882f136005c359c7b + md5: b9a4dacf97241704529131a0dfc0494f + depends: + - python >=3.6 + license: MIT + size: 89455 + timestamp: 1709721146886 +- kind: conda + name: pyqt + version: 5.15.9 + build: py311hf0fb5b6_5 + build_number: 5 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/pyqt-5.15.9-py311hf0fb5b6_5.conda + sha256: 74fcdb8772c7eaf654b32922f77d9a8a1350b3446111c69a32ba4d15be74905a + md5: ec7e45bc76d9d0b69a74a2075932b8e8 + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - pyqt5-sip 12.12.2 py311hb755f60_5 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + - qt-main >=5.15.8,<5.16.0a0 + - sip >=6.7.11,<6.8.0a0 + license: GPL-3.0-only + license_family: GPL + size: 5315719 + timestamp: 1695420475603 +- kind: conda + name: pyqt5-sip + version: 12.12.2 + build: py311hb755f60_5 + build_number: 5 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.12.2-py311hb755f60_5.conda + sha256: cf6936273d92e5213b085bfd9ce1a37defb46b317b6ee991f2712bf4a25b8456 + md5: e4d262cc3600e70b505a6761d29f6207 + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - packaging + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + - sip + - toml + license: GPL-3.0-only + license_family: GPL + size: 85162 + timestamp: 1695418076285 - kind: conda name: pysocks version: 1.7.1 @@ -2334,6 +4681,22 @@ packages: license: Python-2.0 size: 15410083 timestamp: 1673762717308 +- kind: conda + name: python-dateutil + version: 2.9.0 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.9.0-pyhd8ed1ab_0.conda + sha256: f3ceef02ac164a8d3a080d0d32f8e2ebe10dd29e3a685d240e38b3599e146320 + md5: 2cf4264fffb9e6eff6031c5b6884d61c + depends: + - python >=3.7 + - six >=1.5 + license: Apache-2.0 + license_family: APACHE + size: 222742 + timestamp: 1709299922152 - kind: conda name: python_abi version: '3.11' @@ -2447,6 +4810,65 @@ packages: license_family: MIT size: 187795 timestamp: 1695373829282 +- kind: conda + name: qt-main + version: 5.15.8 + build: h5810be5_19 + build_number: 19 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h5810be5_19.conda + sha256: 41228ec12346d640ef1f549885d8438e98b1be0fdeb68cd1dd3938f255cbd719 + md5: 54866f708d43002a514d0b9b0f84bc11 + depends: + - __glibc >=2.17,<3.0.a0 + - alsa-lib >=1.2.10,<1.3.0.0a0 + - dbus >=1.13.6,<2.0a0 + - fontconfig >=2.14.2,<3.0a0 + - fonts-conda-ecosystem + - freetype >=2.12.1,<3.0a0 + - gst-plugins-base >=1.22.9,<1.23.0a0 + - gstreamer >=1.22.9,<1.23.0a0 + - harfbuzz >=8.3.0,<9.0a0 + - icu >=73.2,<74.0a0 + - krb5 >=1.21.2,<1.22.0a0 + - libclang >=15.0.7,<16.0a0 + - libclang13 >=15.0.7 + - libcups >=2.3.3,<2.4.0a0 + - libevent >=2.1.12,<2.1.13.0a0 + - libexpat >=2.5.0,<3.0a0 + - libgcc-ng >=12 + - libglib >=2.78.3,<3.0a0 + - libjpeg-turbo >=3.0.0,<4.0a0 + - libpng >=1.6.42,<1.7.0a0 + - libpq >=16.2,<17.0a0 + - libsqlite >=3.45.1,<4.0a0 + - libstdcxx-ng >=12 + - libxcb >=1.15,<1.16.0a0 + - libxkbcommon >=1.6.0,<2.0a0 + - libxml2 >=2.12.5,<3.0a0 + - libzlib >=1.2.13,<1.3.0a0 + - mysql-libs >=8.0.33,<8.1.0a0 + - nspr >=4.35,<5.0a0 + - nss >=3.97,<4.0a0 + - openssl >=3.2.1,<4.0a0 + - pulseaudio-client >=16.1,<16.2.0a0 + - xcb-util >=0.4.0,<0.5.0a0 + - xcb-util-image >=0.4.0,<0.5.0a0 + - xcb-util-keysyms >=0.4.0,<0.5.0a0 + - xcb-util-renderutil >=0.3.9,<0.4.0a0 + - xcb-util-wm >=0.4.1,<0.5.0a0 + - xorg-libice >=1.1.1,<2.0a0 + - xorg-libsm >=1.2.4,<2.0a0 + - xorg-libx11 >=1.8.7,<2.0a0 + - xorg-libxext >=1.3.4,<2.0a0 + - xorg-xf86vidmodeproto + - zstd >=1.5.5,<1.6.0a0 + constrains: + - qt 5.15.8 + license: LGPL-3.0-only + license_family: LGPL + size: 61337596 + timestamp: 1707958161584 - kind: conda name: re2 version: 2023.06.02 @@ -2647,6 +5069,26 @@ packages: license_family: MIT size: 469644 timestamp: 1708702431036 +- kind: conda + name: sip + version: 6.7.12 + build: py311hb755f60_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/sip-6.7.12-py311hb755f60_0.conda + sha256: 71a0ee22522b232bf50d4d03d012e53cd5d1251d09dffc1c72d7c33a1086fe6f + md5: 02336abab4cb5dd794010ef53c54bd09 + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - packaging + - ply + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + - tomli + license: GPL-3.0-only + license_family: GPL + size: 585197 + timestamp: 1697300605264 - kind: conda name: six version: 1.16.0 @@ -2868,6 +5310,83 @@ packages: license_family: BSD size: 3318875 timestamp: 1699202167581 +- kind: conda + name: toml + version: 0.10.2 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 + sha256: f0f3d697349d6580e4c2f35ba9ce05c65dc34f9f049e85e45da03800b46139c1 + md5: f832c45a477c78bebd107098db465095 + depends: + - python >=2.7 + license: MIT + license_family: MIT + size: 18433 + timestamp: 1604308660817 +- kind: conda + name: tomli + version: 2.0.1 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2 + sha256: 4cd48aba7cd026d17e86886af48d0d2ebc67ed36f87f6534f4b67138f5a5a58f + md5: 5844808ffab9ebdb694585b50ba02a96 + depends: + - python >=3.7 + license: MIT + license_family: MIT + size: 15940 + timestamp: 1644342331069 +- kind: conda + name: tornado + version: '6.4' + build: py311h05b510d_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/tornado-6.4-py311h05b510d_0.conda + sha256: 29c07a81b52310f9679ca05a6f1d3d3ee8c1830f183f91ad8d46f99cc2fb6720 + md5: 241cd427ab1f38b72d6ddda3994c80a7 + depends: + - python >=3.11,<3.12.0a0 + - python >=3.11,<3.12.0a0 *_cpython + - python_abi 3.11.* *_cp311 + license: Apache-2.0 + license_family: Apache + size: 856729 + timestamp: 1708363632330 +- kind: conda + name: tornado + version: '6.4' + build: py311h459d7ec_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/tornado-6.4-py311h459d7ec_0.conda + sha256: 5bb1e24d1767e403183e4cc842d184b2da497e778f0311c5b1d023fb3af9e6b6 + md5: cc7727006191b8f3630936b339a76cd0 + depends: + - libgcc-ng >=12 + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + license: Apache-2.0 + license_family: Apache + size: 853245 + timestamp: 1708363316040 +- kind: conda + name: tornado + version: '6.4' + build: py311he705e18_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/tornado-6.4-py311he705e18_0.conda + sha256: 0b994ce7984953d1d528b7e19a97db0b34da09398feaf592500df637719d5623 + md5: 40845aadca8df7ccc21c393ba3aa9eac + depends: + - python >=3.11,<3.12.0a0 + - python_abi 3.11.* *_cp311 + license: Apache-2.0 + license_family: Apache + size: 857610 + timestamp: 1708363541170 - kind: conda name: traitlets version: 5.14.1 @@ -2910,6 +5429,23 @@ packages: license: LicenseRef-Public-Domain size: 119815 timestamp: 1706886945727 +- kind: conda + name: uhi + version: 0.4.0 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/uhi-0.4.0-pyhd8ed1ab_0.conda + sha256: b91db36c45531787666bb9b486702a268c63fcccd5153acd4257e30ff510fdc6 + md5: aad63b5bb7b7eff6bf4d3c4d731b8d38 + depends: + - numpy >=1.13.3 + - python >=3.6 + - typing_extensions >=3.7 + license: BSD-3-Clause + license_family: BSD + size: 16910 + timestamp: 1697567564540 - kind: conda name: urllib3 version: 2.2.1 @@ -2957,6 +5493,336 @@ packages: license_family: MIT size: 57553 timestamp: 1701013309664 +- kind: conda + name: xcb-util + version: 0.4.0 + build: hd590300_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.0-hd590300_1.conda + sha256: 0c91d87f0efdaadd4e56a5f024f8aab20ec30f90aa2ce9e4ebea05fbc20f71ad + md5: 9bfac7ccd94d54fd21a0501296d60424 + depends: + - libgcc-ng >=12 + - libxcb >=1.13 + - libxcb >=1.15,<1.16.0a0 + license: MIT + license_family: MIT + size: 19728 + timestamp: 1684639166048 +- kind: conda + name: xcb-util-image + version: 0.4.0 + build: h8ee46fc_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-image-0.4.0-h8ee46fc_1.conda + sha256: 92ffd68d2801dbc27afe223e04ae7e78ef605fc8575f107113c93c7bafbd15b0 + md5: 9d7bcddf49cbf727730af10e71022c73 + depends: + - libgcc-ng >=12 + - libxcb >=1.15,<1.16.0a0 + - xcb-util >=0.4.0,<0.5.0a0 + license: MIT + license_family: MIT + size: 24474 + timestamp: 1684679894554 +- kind: conda + name: xcb-util-keysyms + version: 0.4.0 + build: h8ee46fc_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-keysyms-0.4.0-h8ee46fc_1.conda + sha256: 8451d92f25d6054a941b962179180728c48c62aab5bf20ac10fef713d5da6a9a + md5: 632413adcd8bc16b515cab87a2932913 + depends: + - libgcc-ng >=12 + - libxcb >=1.15,<1.16.0a0 + license: MIT + license_family: MIT + size: 14186 + timestamp: 1684680497805 +- kind: conda + name: xcb-util-renderutil + version: 0.3.9 + build: hd590300_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-renderutil-0.3.9-hd590300_1.conda + sha256: 6987588e6fff5892056021c2ea52f7a0deefb2c7348e70d24750e2d60dabf009 + md5: e995b155d938b6779da6ace6c6b13816 + depends: + - libgcc-ng >=12 + - libxcb >=1.13 + - libxcb >=1.15,<1.16.0a0 + license: MIT + license_family: MIT + size: 16955 + timestamp: 1684639112393 +- kind: conda + name: xcb-util-wm + version: 0.4.1 + build: h8ee46fc_1 + build_number: 1 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-wm-0.4.1-h8ee46fc_1.conda + sha256: 08ba7147c7579249b6efd33397dc1a8c2404278053165aaecd39280fee705724 + md5: 90108a432fb5c6150ccfee3f03388656 + depends: + - libgcc-ng >=12 + - libxcb >=1.15,<1.16.0a0 + license: MIT + license_family: MIT + size: 52114 + timestamp: 1684679248466 +- kind: conda + name: xkeyboard-config + version: '2.41' + build: hd590300_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.41-hd590300_0.conda + sha256: 56955610c0747ea7cb026bb8aa9ef165ff41d616e89894538173b8b7dd2ee49a + md5: 81f740407b45e3f9047b3174fa94eb9e + depends: + - libgcc-ng >=12 + - xorg-libx11 >=1.8.7,<2.0a0 + license: MIT + license_family: MIT + size: 898045 + timestamp: 1707104384997 +- kind: conda + name: xorg-kbproto + version: 1.0.7 + build: h7f98852_1002 + build_number: 1002 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xorg-kbproto-1.0.7-h7f98852_1002.tar.bz2 + sha256: e90b0a6a5d41776f11add74aa030f789faf4efd3875c31964d6f9cfa63a10dd1 + md5: 4b230e8381279d76131116660f5a241a + depends: + - libgcc-ng >=9.3.0 + license: MIT + license_family: MIT + size: 27338 + timestamp: 1610027759842 +- kind: conda + name: xorg-libice + version: 1.1.1 + build: hd590300_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xorg-libice-1.1.1-hd590300_0.conda + sha256: 5aa9b3682285bb2bf1a8adc064cb63aff76ef9178769740d855abb42b0d24236 + md5: b462a33c0be1421532f28bfe8f4a7514 + depends: + - libgcc-ng >=12 + license: MIT + license_family: MIT + size: 58469 + timestamp: 1685307573114 +- kind: conda + name: xorg-libsm + version: 1.2.4 + build: h7391055_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.4-h7391055_0.conda + sha256: 089ad5f0453c604e18985480218a84b27009e9e6de9a0fa5f4a20b8778ede1f1 + md5: 93ee23f12bc2e684548181256edd2cf6 + depends: + - libgcc-ng >=12 + - libuuid >=2.38.1,<3.0a0 + - xorg-libice >=1.1.1,<2.0a0 + license: MIT + license_family: MIT + size: 27433 + timestamp: 1685453649160 +- kind: conda + name: xorg-libx11 + version: 1.8.7 + build: h8ee46fc_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.7-h8ee46fc_0.conda + sha256: 7a02a7beac472ae2759498550b5fc5261bf5be7a9a2b4648a3f67818a7bfefcf + md5: 49e482d882669206653b095f5206c05b + depends: + - libgcc-ng >=12 + - libxcb >=1.15,<1.16.0a0 + - xorg-kbproto + - xorg-xextproto >=7.3.0,<8.0a0 + - xorg-xproto + license: MIT + license_family: MIT + size: 828692 + timestamp: 1697056910935 +- kind: conda + name: xorg-libxau + version: 1.0.11 + build: h0dc2134_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/xorg-libxau-1.0.11-h0dc2134_0.conda + sha256: 8a2e398c4f06f10c64e69f56bcf3ddfa30b432201446a0893505e735b346619a + md5: 9566b4c29274125b0266d0177b5eb97b + license: MIT + license_family: MIT + size: 13071 + timestamp: 1684638167647 +- kind: conda + name: xorg-libxau + version: 1.0.11 + build: hb547adb_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxau-1.0.11-hb547adb_0.conda + sha256: 02c313a1cada46912e5b9bdb355cfb4534bfe22143b4ea4ecc419690e793023b + md5: ca73dc4f01ea91e44e3ed76602c5ea61 + license: MIT + license_family: MIT + size: 13667 + timestamp: 1684638272445 +- kind: conda + name: xorg-libxau + version: 1.0.11 + build: hd590300_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxau-1.0.11-hd590300_0.conda + sha256: 309751371d525ce50af7c87811b435c176915239fc9e132b99a25d5e1703f2d4 + md5: 2c80dc38fface310c9bd81b17037fee5 + depends: + - libgcc-ng >=12 + license: MIT + license_family: MIT + size: 14468 + timestamp: 1684637984591 +- kind: conda + name: xorg-libxdmcp + version: 1.1.3 + build: h27ca646_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxdmcp-1.1.3-h27ca646_0.tar.bz2 + sha256: d9a2fb4762779994718832f05a7d62ab2dcf6103a312235267628b5187ce88f7 + md5: 6738b13f7fadc18725965abdd4129c36 + license: MIT + license_family: MIT + size: 18164 + timestamp: 1610071737668 +- kind: conda + name: xorg-libxdmcp + version: 1.1.3 + build: h35c211d_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/xorg-libxdmcp-1.1.3-h35c211d_0.tar.bz2 + sha256: 485421c16f03a01b8ed09984e0b2ababdbb3527e1abf354ff7646f8329be905f + md5: 86ac76d6bf1cbb9621943eb3bd9ae36e + license: MIT + license_family: MIT + size: 17225 + timestamp: 1610071995461 +- kind: conda + name: xorg-libxdmcp + version: 1.1.3 + build: h7f98852_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxdmcp-1.1.3-h7f98852_0.tar.bz2 + sha256: 4df7c5ee11b8686d3453e7f3f4aa20ceef441262b49860733066c52cfd0e4a77 + md5: be93aabceefa2fac576e971aef407908 + depends: + - libgcc-ng >=9.3.0 + license: MIT + license_family: MIT + size: 19126 + timestamp: 1610071769228 +- kind: conda + name: xorg-libxext + version: 1.3.4 + build: h0b41bf4_2 + build_number: 2 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxext-1.3.4-h0b41bf4_2.conda + sha256: 73e5cfbdff41ef8a844441f884412aa5a585a0f0632ec901da035a03e1fe1249 + md5: 82b6df12252e6f32402b96dacc656fec + depends: + - libgcc-ng >=12 + - xorg-libx11 >=1.7.2,<2.0a0 + - xorg-xextproto + license: MIT + license_family: MIT + size: 50143 + timestamp: 1677036907815 +- kind: conda + name: xorg-libxrender + version: 0.9.11 + build: hd590300_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrender-0.9.11-hd590300_0.conda + sha256: 26da4d1911473c965c32ce2b4ff7572349719eaacb88a066db8d968a4132c3f7 + md5: ed67c36f215b310412b2af935bf3e530 + depends: + - libgcc-ng >=12 + - xorg-libx11 >=1.8.6,<2.0a0 + - xorg-renderproto + license: MIT + license_family: MIT + size: 37770 + timestamp: 1688300707994 +- kind: conda + name: xorg-renderproto + version: 0.11.1 + build: h7f98852_1002 + build_number: 1002 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xorg-renderproto-0.11.1-h7f98852_1002.tar.bz2 + sha256: 38942930f233d1898594dd9edf4b0c0786f3dbc12065a0c308634c37fd936034 + md5: 06feff3d2634e3097ce2fe681474b534 + depends: + - libgcc-ng >=9.3.0 + license: MIT + license_family: MIT + size: 9621 + timestamp: 1614866326326 +- kind: conda + name: xorg-xextproto + version: 7.3.0 + build: h0b41bf4_1003 + build_number: 1003 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xorg-xextproto-7.3.0-h0b41bf4_1003.conda + sha256: b8dda3b560e8a7830fe23be1c58cc41f407b2e20ae2f3b6901eb5842ba62b743 + md5: bce9f945da8ad2ae9b1d7165a64d0f87 + depends: + - libgcc-ng >=12 + license: MIT + license_family: MIT + size: 30270 + timestamp: 1677036833037 +- kind: conda + name: xorg-xf86vidmodeproto + version: 2.3.1 + build: h7f98852_1002 + build_number: 1002 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xorg-xf86vidmodeproto-2.3.1-h7f98852_1002.tar.bz2 + sha256: 43398aeacad5b8753b7a1c12cb6bca36124e0c842330372635879c350c430791 + md5: 3ceea9668625c18f19530de98b15d5b0 + depends: + - libgcc-ng >=9.3.0 + license: MIT + license_family: MIT + size: 23875 + timestamp: 1620067286978 +- kind: conda + name: xorg-xproto + version: 7.0.31 + build: h7f98852_1007 + build_number: 1007 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/xorg-xproto-7.0.31-h7f98852_1007.tar.bz2 + sha256: f197bb742a17c78234c24605ad1fe2d88b1d25f332b75d73e5ba8cf8fbc2a10d + md5: b4a4381d54784606820704f7b5f05a15 + depends: + - libgcc-ng >=9.3.0 + license: MIT + license_family: MIT + size: 74922 + timestamp: 1607291557628 - kind: conda name: xz version: 5.2.6 @@ -3048,3 +5914,63 @@ packages: license_family: MIT size: 18954 timestamp: 1695255262261 +- kind: conda + name: zlib + version: 1.2.13 + build: hd590300_5 + build_number: 5 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.13-hd590300_5.conda + sha256: 9887a04d7e7cb14bd2b52fa01858f05a6d7f002c890f618d9fcd864adbfecb1b + md5: 68c34ec6149623be41a1933ab996a209 + depends: + - libgcc-ng >=12 + - libzlib 1.2.13 hd590300_5 + license: Zlib + license_family: Other + size: 92825 + timestamp: 1686575231103 +- kind: conda + name: zstd + version: 1.5.5 + build: h4f39d0f_0 + subdir: osx-arm64 + url: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.5-h4f39d0f_0.conda + sha256: 7e1fe6057628bbb56849a6741455bbb88705bae6d6646257e57904ac5ee5a481 + md5: 5b212cfb7f9d71d603ad891879dc7933 + depends: + - libzlib >=1.2.13,<1.3.0a0 + license: BSD-3-Clause + license_family: BSD + size: 400508 + timestamp: 1693151393180 +- kind: conda + name: zstd + version: 1.5.5 + build: h829000d_0 + subdir: osx-64 + url: https://conda.anaconda.org/conda-forge/osx-64/zstd-1.5.5-h829000d_0.conda + sha256: d54e31d3d8de5e254c0804abd984807b8ae5cd3708d758a8bf1adff1f5df166c + md5: 80abc41d0c48b82fe0f04e7f42f5cb7e + depends: + - libzlib >=1.2.13,<1.3.0a0 + license: BSD-3-Clause + license_family: BSD + size: 499383 + timestamp: 1693151312586 +- kind: conda + name: zstd + version: 1.5.5 + build: hfc55251_0 + subdir: linux-64 + url: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.5-hfc55251_0.conda + sha256: 607cbeb1a533be98ba96cf5cdf0ddbb101c78019f1fda063261871dad6248609 + md5: 04b88013080254850d6c01ed54810589 + depends: + - libgcc-ng >=12 + - libstdcxx-ng >=12 + - libzlib >=1.2.13,<1.3.0a0 + license: BSD-3-Clause + license_family: BSD + size: 545199 + timestamp: 1693151163452 diff --git a/pixi.toml b/pixi.toml index 578af90..f648bba 100644 --- a/pixi.toml +++ b/pixi.toml @@ -16,6 +16,9 @@ ipython = "*" jaxlib = "*" jax = "*" myst-parser = "*" +matplotlib = "*" +mplhep = "*" +imageio = "*" [host-dependencies] pip = "*" From 7808fea0f1f00306306fbd741f9eb25a677c70a6 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Thu, 7 Mar 2024 15:39:40 +0100 Subject: [PATCH 09/22] add helper function to reduce boilerplate code --- examples/model.py | 4 +--- src/evermore/__init__.py | 2 ++ src/evermore/parameter.py | 14 ++++++++++++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/examples/model.py b/examples/model.py index 25cebf8..d9bb097 100644 --- a/examples/model.py +++ b/examples/model.py @@ -17,9 +17,7 @@ class SPlusBModel(eqx.Module): def __init__(self) -> None: self.mu = evm.Parameter(value=jnp.array([1.0])) - self.norm1 = evm.Parameter() - self.norm2 = evm.Parameter() - self.shape1 = evm.Parameter() + self = evm.parameter.auto_init(self) def __call__(self, hists: dict[Any, jax.Array]) -> dict[str, jax.Array]: expectations = {} diff --git a/src/evermore/__init__.py b/src/evermore/__init__.py index f9a52d6..00c2642 100644 --- a/src/evermore/__init__.py +++ b/src/evermore/__init__.py @@ -19,6 +19,7 @@ "__version__", "effect", "loss", + "parameter", "pdf", "util", "sample", @@ -38,6 +39,7 @@ def __dir__(): from evermore import ( # noqa: E402 effect, loss, + parameter, pdf, sample, util, diff --git a/src/evermore/parameter.py b/src/evermore/parameter.py index 195f239..76fce2b 100644 --- a/src/evermore/parameter.py +++ b/src/evermore/parameter.py @@ -14,6 +14,7 @@ __all__ = [ "Parameter", + "auto_init", ] @@ -76,3 +77,16 @@ def shape(self, up: Array, down: Array) -> modifier: import evermore as evm return evm.modifier(parameter=self, effect=evm.effect.shape(up=up, down=down)) + + +def auto_init(module: eqx.Module) -> eqx.Module: + import dataclasses + import typing + + type_hints = typing.get_type_hints(module.__class__) + for field in dataclasses.fields(module): + name = field.name + hint = type_hints[name] + if issubclass(hint, Parameter) and not hasattr(module, name): + setattr(module, name, hint()) + return module From 61e5f2a534a5d6ea25c3574b507639bd616881eb Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Thu, 7 Mar 2024 15:42:31 +0100 Subject: [PATCH 10/22] add non-ready examples to gitignore --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index e0a725d..e991edd 100644 --- a/.gitignore +++ b/.gitignore @@ -160,6 +160,10 @@ Thumbs.db # if examples are run examples/*.eqx +# analysis optimisation example +examples/analysis_opt.py +plots_analysis_opt/ + test/ .vscode/ .zed/ From 77f12e73e722c3785530ecda54ce7a5a99df84cd Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Fri, 8 Mar 2024 15:52:27 +0100 Subject: [PATCH 11/22] properly add bin_by_bin staterrors --- examples/bin_by_bin_uncs.py | 68 +++++ examples/dnn_weights_constraint.py | 10 +- examples/model.py | 64 +++-- examples/nll_fit.py | 18 +- examples/toy_generation.py | 4 +- src/evermore/__init__.py | 17 +- src/evermore/custom_types.py | 2 +- src/evermore/effect.py | 33 +-- src/evermore/loss.py | 12 +- src/evermore/modifier.py | 432 +++++++++-------------------- src/evermore/parameter.py | 87 ++++-- src/evermore/pdf.py | 114 ++------ src/evermore/sample.py | 16 +- tests/test_parameter.py | 38 +-- tests/test_pdf.py | 10 +- 15 files changed, 422 insertions(+), 503 deletions(-) create mode 100644 examples/bin_by_bin_uncs.py diff --git a/examples/bin_by_bin_uncs.py b/examples/bin_by_bin_uncs.py new file mode 100644 index 0000000..b7d16cb --- /dev/null +++ b/examples/bin_by_bin_uncs.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import equinox as eqx +import jax.numpy as jnp +import jax.tree_util as jtu +from jaxtyping import Array, PyTree + +import evermore as evm + + +class SPlusBModel(eqx.Module): + mu: evm.Parameter + norm1: evm.Parameter + norm2: evm.Parameter + staterrors: PyTree[evm.Parameter] + + def __init__(self, hists: dict[str, Array]) -> None: + self.mu = evm.Parameter(value=1.0, lower=0.0, upper=10.0) + self.staterrors = evm.parameter.staterrors(hists=hists) + self = evm.parameter.auto_init(self) + + def __call__(self, hists: dict, histsw2: dict) -> dict[str, Array]: + expectations = {} + + # calculate widths of the sum of the nominal histograms for gaussian MC stat + sqrtw2 = jtu.tree_map(jnp.sqrt, histsw2) + widths = evm.util.sum_leaves(sqrtw2) / evm.util.sum_leaves(hists) + gauss_mcstat = self.staterrors["gauss"].gauss(widths) + # barlow-beeston-like condition: above 10 use gauss, below use poisson + mask = evm.util.sum_leaves(hists) > 10 + + # signal process + signal_poisson = self.staterrors["poisson"]["signal"].poisson(hists["signal"]) + signal_mc_stats = evm.modifier.where(mask, gauss_mcstat, signal_poisson) + mu_mod = self.mu.unconstrained() + expectations["signal"] = (signal_mc_stats @ mu_mod)(hists["signal"]) + + # bkg1 process + bkg1_poisson = self.staterrors["poisson"]["bkg1"].poisson(hists["bkg1"]) + bkg1_mc_stats = evm.modifier.where(mask, gauss_mcstat, bkg1_poisson) + norm1_mod = self.norm1.lnN(jnp.array([0.9, 1.1])) + expectations["bkg1"] = (bkg1_mc_stats @ norm1_mod)(hists["bkg1"]) + + # bkg2 process + bkg2_poisson = self.staterrors["poisson"]["bkg2"].poisson(hists["bkg2"]) + bkg2_mc_stats = evm.modifier.where(mask, gauss_mcstat, bkg2_poisson) + norm2_mod = self.norm2.lnN(jnp.array([0.95, 1.05])) + expectations["bkg2"] = (bkg2_mc_stats @ norm2_mod)(hists["bkg2"]) + + # return the modified expectations + return expectations + + +hists = { + "signal": jnp.array([3]), + "bkg1": jnp.array([10]), + "bkg2": jnp.array([20]), +} +histsw2 = { + "signal": jnp.array([5]), + "bkg1": jnp.array([11]), + "bkg2": jnp.array([25]), +} + +model = SPlusBModel(hists) + +# test the model +expectations = model(hists, histsw2) diff --git a/examples/dnn_weights_constraint.py b/examples/dnn_weights_constraint.py index 62d2a4c..f3bb737 100644 --- a/examples/dnn_weights_constraint.py +++ b/examples/dnn_weights_constraint.py @@ -12,8 +12,14 @@ class LinearConstrained(eqx.Module): def __init__(self, in_size, out_size, key): wkey, bkey = jax.random.split(key) # weights - self.weights = evm.Parameter(value=jax.random.normal(wkey, (out_size, in_size))) - self.weights.constraints.add(evm.pdf.Gauss(mean=0.0, width=0.5)) + constraint = evm.pdf.Gauss( + mean=jnp.zeros((out_size, in_size)), + width=jnp.full((out_size, in_size), 0.5), + ) + self.weights = evm.Parameter( + value=jax.random.normal(wkey, (out_size, in_size)), constraint=constraint + ) + # biases self.biases = jax.random.normal(bkey, (out_size,)) diff --git a/examples/model.py b/examples/model.py index d9bb097..e0f4886 100644 --- a/examples/model.py +++ b/examples/model.py @@ -1,10 +1,8 @@ from __future__ import annotations -from typing import Any - import equinox as eqx -import jax import jax.numpy as jnp +from jaxtyping import Array import evermore as evm @@ -15,47 +13,65 @@ class SPlusBModel(eqx.Module): norm2: evm.Parameter shape1: evm.Parameter - def __init__(self) -> None: + def __init__(self, sumw: dict[str, Array], sumw2: dict[str, Array]) -> None: self.mu = evm.Parameter(value=jnp.array([1.0])) self = evm.parameter.auto_init(self) - def __call__(self, hists: dict[Any, jax.Array]) -> dict[str, jax.Array]: + def __call__(self, hists: dict) -> dict[str, Array]: expectations = {} # signal process sig_mod = self.mu.unconstrained() - expectations["signal"] = sig_mod(hists[("signal", "nominal")]) + expectations["signal"] = sig_mod(hists["nominal"]["signal"]) # bkg1 process - bkg1_mod = self.norm1.lnN(width=jnp.array([0.9, 1.1])) @ self.shape1.shape( - up=hists[("bkg1", "shape_up")], - down=hists[("bkg1", "shape_down")], + bkg1_lnN = self.norm1.lnN(width=jnp.array([0.9, 1.1])) + bkg1_shape = self.shape1.shape( + up=hists["shape_up"]["bkg1"], + down=hists["shape_down"]["bkg1"], ) - expectations["bkg1"] = bkg1_mod(hists[("bkg1", "nominal")]) + # combine modifiers + bkg1_mod = bkg1_lnN @ bkg1_shape + expectations["bkg1"] = bkg1_mod(hists["nominal"]["bkg1"]) # bkg2 process - bkg2_mod = self.norm2.lnN(width=jnp.array([0.95, 1.05])) @ self.shape1.shape( - up=hists[("bkg2", "shape_up")], - down=hists[("bkg2", "shape_down")], + bkg2_lnN = self.norm2.lnN(width=jnp.array([0.95, 1.05])) + bkg2_shape = self.shape1.shape( + up=hists["shape_up"]["bkg2"], + down=hists["shape_down"]["bkg2"], ) - expectations["bkg2"] = bkg2_mod(hists[("bkg2", "nominal")]) + # combine modifiers + bkg2_mod = bkg2_lnN @ bkg2_shape + expectations["bkg2"] = bkg2_mod(hists["nominal"]["bkg2"]) # return the modified expectations return expectations -model = SPlusBModel() - - hists = { - ("signal", "nominal"): jnp.array([3]), - ("bkg1", "nominal"): jnp.array([10]), - ("bkg2", "nominal"): jnp.array([20]), - ("bkg1", "shape_up"): jnp.array([12]), - ("bkg1", "shape_down"): jnp.array([8]), - ("bkg2", "shape_up"): jnp.array([23]), - ("bkg2", "shape_down"): jnp.array([19]), + "nominal": { + "signal": jnp.array([3]), + "bkg1": jnp.array([10]), + "bkg2": jnp.array([20]), + }, + "shape_up": { + "bkg1": jnp.array([12]), + "bkg2": jnp.array([23]), + }, + "shape_down": { + "bkg1": jnp.array([8]), + "bkg2": jnp.array([19]), + }, } +sumw = hists["nominal"] +sumw2 = { + "signal": jnp.array([5]), + "bkg1": jnp.array([11]), + "bkg2": jnp.array([25]), +} + +model = SPlusBModel(sumw, sumw2) + observation = jnp.array([37]) expectations = model(hists) diff --git a/examples/nll_fit.py b/examples/nll_fit.py index 3d42579..4f30654 100644 --- a/examples/nll_fit.py +++ b/examples/nll_fit.py @@ -1,4 +1,5 @@ import equinox as eqx +import jax import jax.numpy as jnp import optax from model import hists, model, observation @@ -35,5 +36,20 @@ def make_step(model, opt_state, events, observation): # minimize model with 1000 steps -for _ in range(1000): +for step in range(1000): + if step % 100 == 0: + loss_val = loss(model, hists, observation) + print(f"{step=} - {loss_val=:.6f}") model, opt_state = make_step(model, opt_state, hists, observation) + + +# For low overhead it is recommended to use jax.lax.fori_loop. +# In case you want to jit the for loop, you can use the following function, +# this will prevent jax from unrolling the loop and creating a huge graph +@jax.jit +def fit(steps: int = 1000) -> tuple[eqx.Module, tuple]: + def fun(step, model_optstate): + model, opt_state = model_optstate + return make_step(model, opt_state, hists, observation) + + return jax.lax.fori_loop(0, steps, fun, (model, opt_state)) diff --git a/examples/toy_generation.py b/examples/toy_generation.py index e7038d9..96ad8ef 100644 --- a/examples/toy_generation.py +++ b/examples/toy_generation.py @@ -1,5 +1,3 @@ -from typing import Any - import equinox as eqx import jax from jaxtyping import Array, PRNGKeyArray @@ -17,7 +15,7 @@ def toy_expectation( key: PRNGKeyArray, module: eqx.Module, - hists: dict[Any, Array], + hists: dict, ) -> Array: toymodel = evm.sample.toy_module(model, key) expectations = toymodel(hists) diff --git a/src/evermore/__init__.py b/src/evermore/__init__.py index 00c2642..9b18ef0 100644 --- a/src/evermore/__init__.py +++ b/src/evermore/__init__.py @@ -23,12 +23,10 @@ "pdf", "util", "sample", + "modifier", # explicitely expose some classes "Parameter", - "modifier", - # "staterror", - # "autostaterrors", - "compose", + "Modifier", ] @@ -39,18 +37,11 @@ def __dir__(): from evermore import ( # noqa: E402 effect, loss, + modifier, parameter, pdf, sample, util, ) - -# from evermore.model import Model, Result -from evermore.modifier import ( # noqa: E402 - # autostaterrors, - compose, - modifier, -) - -# staterror, +from evermore.modifier import Modifier # noqa: E402 from evermore.parameter import Parameter # noqa: E402 diff --git a/src/evermore/custom_types.py b/src/evermore/custom_types.py index b901250..737ec9d 100644 --- a/src/evermore/custom_types.py +++ b/src/evermore/custom_types.py @@ -7,7 +7,7 @@ class Sentinel: - __slots__ = ("repr",) + repr: str def __init__(self, repr: str) -> None: self.repr = repr diff --git a/src/evermore/effect.py b/src/evermore/effect.py index 9fb4642..6256a9d 100644 --- a/src/evermore/effect.py +++ b/src/evermore/effect.py @@ -8,7 +8,7 @@ from evermore.custom_types import AddOrMul from evermore.parameter import Parameter -from evermore.pdf import Flat, Gauss, HashablePDF, Poisson +from evermore.pdf import PDF, Flat, Gauss, Poisson from evermore.util import as1darray if TYPE_CHECKING: @@ -34,9 +34,8 @@ def __dir__(): class Effect(eqx.Module): apply_op: AbstractClassVar[AddOrMul] - @property @abc.abstractmethod - def constraint(self) -> HashablePDF: + def constraint(self, parameter: Parameter) -> PDF: ... @abc.abstractmethod @@ -47,8 +46,7 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: class unconstrained(Effect): apply_op: ClassVar[AddOrMul] = operator.mul - @property - def constraint(self) -> HashablePDF: + def constraint(self, parameter: Parameter) -> PDF: return Flat() def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: @@ -66,9 +64,10 @@ class gauss(Effect): def __init__(self, width: Array) -> None: self.width = width - @property - def constraint(self) -> HashablePDF: - return Gauss(mean=0.0, width=1.0) + def constraint(self, parameter: Parameter) -> PDF: + return Gauss( + mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) + ) def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: """ @@ -124,9 +123,10 @@ def vshift(self, sf: Array, sumw: Array) -> Array: ) ) - @property - def constraint(self) -> HashablePDF: - return Gauss(mean=0.0, width=1.0) + def constraint(self, parameter: Parameter) -> PDF: + return Gauss( + mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) + ) def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: sf = parameter.value @@ -165,9 +165,10 @@ def interpolate(self, parameter: Parameter) -> Array: jnp.abs(x) >= 0.5, jnp.where(x >= 0, hi, lo), avg + alpha * halfdiff ) - @property - def constraint(self) -> HashablePDF: - return Gauss(mean=0.0, width=1.0) + def constraint(self, parameter: Parameter) -> PDF: + return Gauss( + mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) + ) def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: """ @@ -198,8 +199,8 @@ class poisson(Effect): def __init__(self, lamb: Array) -> None: self.lamb = lamb - @property - def constraint(self) -> HashablePDF: + def constraint(self, parameter: Parameter) -> PDF: + assert parameter.value.shape == self.lamb.shape return Poisson(lamb=self.lamb) def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: diff --git a/src/evermore/loss.py b/src/evermore/loss.py index e638aa9..e8209dc 100644 --- a/src/evermore/loss.py +++ b/src/evermore/loss.py @@ -1,11 +1,14 @@ from collections.abc import Callable +from typing import cast import equinox as eqx import jax import jax.numpy as jnp from jaxtyping import Array +from evermore.custom_types import _NoValue from evermore.parameter import Parameter +from evermore.pdf import PDF from evermore.util import _params_map __all__ = [ @@ -22,11 +25,10 @@ def get_param_constraints(module: eqx.Module) -> dict: constraints = {} def _constraint(param: Parameter) -> Array: - if param.constraints: - if len(param.constraints) > 1: - msg = f"More than one constraint per parameter is not allowed. Got: {param.constraints}" - raise ValueError(msg) - return next(iter(param.constraints)).logpdf(param.value) + constraint = param.constraint + if constraint is not _NoValue: + constraint = cast(PDF, constraint) + return constraint.logpdf(param.value) return jnp.array([0.0]) # constraints from pdfs diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index 422effa..f8eee16 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -1,9 +1,8 @@ from __future__ import annotations -import abc import operator from functools import reduce -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import equinox as eqx import jax @@ -20,10 +19,9 @@ from evermore.effect import Effect __all__ = [ - "modifier", + "Modifier", "compose", - # "staterror", - # "autostaterrors", + "where", ] @@ -31,13 +29,7 @@ def __dir__(): return __all__ -class ModifierBase(eqx.Module): - @abc.abstractmethod - def __call__(self, sumw: Array) -> Array: - ... - - -class modifier(ModifierBase): +class Modifier(eqx.Module): """ Create a new modifier for a given parameter and penalty. @@ -48,30 +40,35 @@ class modifier(ModifierBase): import jax.numpy as jnp import evermore as evm - mu = evm.Parameter(value=1.1, bounds=(0, 100)) - norm = evm.Parameter(value=0.0, bounds=(-jnp.inf, jnp.inf)) + mu = evm.Parameter(value=1.1) + norm = evm.Parameter(value=0.0) # create a new parameter and a penalty - modify = evm.modifier(name="mu", parameter=mu, effect=evm.effect.unconstrained()) + modify = evm.modifier(parameter=mu, effect=evm.effect.unconstrained()) + # or shorthand + modify = mu.unconstrained() # apply the modifier modify(jnp.array([10, 20, 30])) # -> Array([11., 22., 33.], dtype=float32, weak_type=True), # lnN effect - modify = evm.modifier(name="norm", parameter=norm, effect=evm.effect.lnN((0.8, 1.2))) - modify(jnp.array([10, 20, 30])) + modify = evm.modifier(parameter=norm, effect=evm.effect.lnN(jnp.array([0.8, 1.2]))) + # or shorthand + modify = norm.lnN(jnp.array([0.8, 1.2])) # poisson effect hist = jnp.array([10, 20, 30]) - modify = evm.modifier(name="norm", parameter=norm, effect=evm.effect.poisson(hist)) - modify(jnp.array([10, 20, 30])) + modify = evm.modifier(parameter=norm, effect=evm.effect.poisson(hist)) + # or shorthand + modify = norm.poisson(hist) # shape effect up = jnp.array([12, 23, 35]) down = jnp.array([8, 19, 26]) - modify = evm.modifier(name="norm", parameter=norm, effect=evm.effect.shape(up, down)) - modify(jnp.array([10, 20, 30])) + modify = evm.modifier(parameter=norm, effect=evm.effect.shape(up, down)) + # or shorthand + modify = norm.shape(up, down) """ parameter: Parameter @@ -80,7 +77,10 @@ class modifier(ModifierBase): def __init__(self, parameter: Parameter, effect: Effect = DEFAULT_EFFECT) -> None: self.parameter = parameter self.effect = effect - self.parameter.constraints.add(self.effect.constraint) + + # first time: set the constraint pdf + constraint = self.effect.constraint(parameter=self.parameter) + self.parameter._set_constraint(constraint, overwrite=False) def scale_factor(self, sumw: Array) -> Array: return self.effect.scale_factor(parameter=self.parameter, sumw=sumw) @@ -92,11 +92,62 @@ def __call__(self, sumw: Array) -> Array: shift = jnp.broadcast_to(shift, sumw.shape) return op(shift, sumw) # type: ignore[call-arg] - def __matmul__(self, other: modifier) -> compose: + def __matmul__(self, other: Composable) -> compose: return compose(self, other) -class compose(ModifierBase): +class where(eqx.Module): + """ + Combine two modifiers based on a condition. + + The condition is a boolean array, and the two modifiers are applied to the data based on the condition. + + Example: + + .. code-block:: python + + import jax.numpy as jnp + import evermore as evm + + hist = jnp.array([5, 20, 30]) + syst = evm.Parameter(value=0.0) + + norm = syst.lnN(jnp.array([0.9, 1.1])) + shape = syst.shape(up=jnp.array([7, 22, 31]), down=jnp.array([4, 16, 27])) + + modifier = evm.modifier.where(hist < 10, norm, shape) + + # apply + modifier(hist) + """ + + condition: Array = eqx.field(static=True) + modifier_true: Modifier + modifier_false: Modifier + + def scale_factor(self, sumw: Array) -> Array: + return jnp.where( + self.condition, + self.modifier_true.scale_factor(sumw), + self.modifier_false.scale_factor(sumw), + ) + + @jax.named_scope("evm.where") + def __call__(self, sumw: Array) -> Array: + op_true = self.modifier_true.effect.apply_op + op_false = self.modifier_false.effect.apply_op + sf = self.scale_factor(sumw=sumw) + return jnp.where( + self.condition, + op_true(jnp.atleast_1d(sf), sumw), # type: ignore[call-arg] + op_false(jnp.atleast_1d(sf), sumw), # type: ignore[call-arg] + ) + + def __matmul__(self, other: Composable) -> compose: + return compose(self, other) + + +class compose(eqx.Module): """ Composition of multiple modifiers, i.e.: `(f ∘ g ∘ h)(hist) = f(hist) * g(hist) * h(hist)` It behaves like a single modifier, but it is composed of multiple modifiers; it can be arbitrarly nested. @@ -108,33 +159,44 @@ class compose(ModifierBase): import jax.numpy as jnp import evermore as evm - mu = evm.Parameter(value=1.1, bounds=(0, 100)) - sigma = evm.Parameter(value=0.1, bounds=(-100, 100)) + mu = evm.Parameter(value=1.1) + sigma = evm.Parameter(value=0.1) + sigma2 = evm.Parameter(value=0.2) + + hist = jnp.array([10, 20, 30]) + + # all bins with bin content below 10 (threshold) are treated as poisson, else gauss # create a new parameter and a composition of modifiers + mu_mod = mu.constrained() + sigma_mod = sigma.lnN(jnp.array([0.9, 1.1])) + sigma2_mod = sigma2.lnN(jnp.array([0.95, 1.05])) composition = evm.compose( - evm.modifier(name="mu", parameter=mu), - evm.modifier(name="sigma1", parameter=sigma, effect=evm.effect.lnN((0.9, 1.1))), + mu_mod, + sigma_mod, + evm.modifier.where(hist < 15, sigma2_mod, sigma_mod), ) + # or shorthand + composition = mu_mod @ sigma_mod @ evm.modifier.where(hist < 15, sigma2_mod, sigma_mod) # apply the composition - composition(jnp.array([10, 20, 30])) + composition(hist) # nest compositions composition = evm.compose( composition, - evm.modifier(name="sigma2", parameter=sigma, effect=evm.effect.lnN((0.8, 1.2))), + evm.modifier(parameter=sigma, effect=evm.effect.lnN(jnp.array([0.8, 1.2]))), ) # jit import equinox as eqx - eqx.filter_jit(composition)(jnp.array([10, 20, 30])) + eqx.filter_jit(composition)(hist) """ - modifiers: list[modifier] + modifiers: list[Composable] - def __init__(self, *modifiers: modifier) -> None: + def __init__(self, *modifiers: Composable) -> None: self.modifiers = list(modifiers) # unroll nested compositions _modifiers = [] @@ -142,8 +204,9 @@ def __init__(self, *modifiers: modifier) -> None: if isinstance(mod, compose): _modifiers.extend(mod.modifiers) else: - assert isinstance(mod, modifier) + assert isinstance(mod, Modifier | where) _modifiers.append(mod) + # by now all modifiers are either modifier or staterror self.modifiers = _modifiers def __len__(self) -> int: @@ -151,7 +214,7 @@ def __len__(self) -> int: @jax.named_scope("evm.compose") def __call__(self, sumw: Array) -> Array: - def _prep_shift(modifier: modifier, sumw: Array) -> Array: + def _prep_shift(modifier: Modifier | where, sumw: Array) -> Array: shift = modifier.scale_factor(sumw=sumw) shift = jnp.atleast_1d(shift) return jnp.broadcast_to(shift, sumw.shape) @@ -159,12 +222,42 @@ def _prep_shift(modifier: modifier, sumw: Array) -> Array: # collect all multiplicative and additive shifts shifts: dict[AddOrMul, list] = {operator.mul: [], operator.add: []} for m in range(len(self)): - modifier = self.modifiers[m] - if modifier.effect.apply_op is operator.mul: - shifts[operator.mul].append(_prep_shift(modifier, sumw)) - elif modifier.effect.apply_op is operator.add: - shifts[operator.add].append(_prep_shift(modifier, sumw)) - + mod = self.modifiers[m] + # cast to modifier | staterror, we know it is one of them + # because we unrolled nested compositions in __init__ + mod = cast(Modifier | where, mod) + sf = _prep_shift(mod, sumw) + if isinstance(mod, Modifier): + if mod.effect.apply_op is operator.mul: + shifts[operator.mul].append(sf) + elif mod.effect.apply_op is operator.add: + shifts[operator.add].append(sf) + else: + msg = f"Unsupported apply_op {mod.effect.apply_op} for Modifier {mod}. Only multiplicative and additive effects are supported." + raise ValueError(msg) + elif isinstance(mod, where): + op_true = mod.modifier_true.effect.apply_op + op_false = mod.modifier_false.effect.apply_op + # if both modifiers are multiplicative: + if op_true is operator.mul and op_false is operator.mul: + shifts[operator.mul].append(sf) + # if both modifiers are additive: + elif op_true is operator.add and op_false is operator.add: + shifts[operator.add].append(sf) + # if one is multiplicative and the other is additive: + elif op_true is operator.mul and op_false is operator.add: + _mult_sf = jnp.where(mod.condition, sf, 1.0) + _add_sf = jnp.where(mod.condition, sf, 0.0) + shifts[operator.mul].append(_mult_sf) + shifts[operator.add].append(_add_sf) + elif op_true is operator.add and op_false is operator.mul: + _mult_sf = jnp.where(mod.condition, 1.0, sf) + _add_sf = jnp.where(mod.condition, 0.0, sf) + shifts[operator.mul].append(_mult_sf) + shifts[operator.add].append(_add_sf) + else: + msg = f"Unsupported apply_op {op_true} and {op_false} for 'where' Modifier {mod}. Only multiplicative and additive effects are supported." + raise ValueError(msg) # calculate the product with for operator.mul _mult_fact = reduce(operator.mul, shifts[operator.mul], 1.0) # calculate the sum for operator.add @@ -172,257 +265,8 @@ def _prep_shift(modifier: modifier, sumw: Array) -> Array: # apply return _mult_fact * (sumw + _add_shift) + def __matmul__(self, other: Composable) -> compose: + return compose(self, other) + -# class staterror(ModifierBase): -# """ -# Create a staterror (barlow-beeston) modifier which acts on each bin with a different _underlying_ modifier. - -# *Caution*: The instantiation of a `staterror` is not compatible with JAX-transformations (e.g. `jax.jit`)! - -# Example: - -# .. code-block:: python - -# import jax.numpy as jnp -# import evermore as evm - -# hist = jnp.array([10, 20, 30]) - -# p1 = evm.Parameter(value=1.0) -# p2 = evm.Parameter(value=0.0) -# p3 = evm.Parameter(value=0.0) - -# # all bins with bin content below 10 (threshold) are treated as poisson, else gauss -# modify = evm.staterror( -# parameters={1: p1, 2: p2, 3: p3}, -# sumw=hist, -# sumw2=hist, -# threshold=10.0, -# ) -# modify(hist) -# # -> Array([13.162277, 20. , 30. ], dtype=float32) - -# # jit -# import equinox as eqx - -# fast_modify = eqx.filter_jit(modify) -# """ - -# parameters: dict[str, Parameter] -# sumw: Array -# sumw2: Array -# sumw2sqrt: Array -# widths: Array -# mask: Array -# threshold: float - -# def __init__( -# self, -# parameters: dict[str, Parameter], -# sumw: Array, -# sumw2: Array, -# threshold: float, -# ) -> None: -# self.parameters = parameters -# self.sumw = sumw -# self.sumw2 = sumw2 -# self.sumw2sqrt = jnp.sqrt(sumw2) -# self.threshold = threshold - -# # calculate width -# self.widths = self.sumw2sqrt / self.sumw - -# # store if sumw is below threshold -# self.mask = self.sumw < self.threshold - -# for i, name in enumerate(self.parameters): -# param = self.parameters[name] -# effect = poisson(self.sumw[i]) if self.mask[i] else gauss(self.widths[i]) -# param.constraints.add(effect.constraint) - -# def __check_init__(self): -# if not len(self.parameters) == len(self.sumw2) == len(self.sumw): -# msg = ( -# f"Length of parameters ({len(self.parameters)}), " -# f"sumw2 ({len(self.sumw2)}) and sumw ({len(self.sumw)}) " -# "must be the same." -# ) -# raise ValueError(msg) -# if not self.threshold > 0.0: -# msg = f"Threshold must be >= 0.0, got: {self.threshold}" -# raise ValueError(msg) - -# def scale_factor(self, sumw: Array) -> Array: -# from functools import partial - -# assert len(sumw) == len(self.parameters) == len(self.sumw2) - -# values = jnp.concatenate([param.value for param in self.parameters.values()]) -# idxs = jnp.arange(len(sumw)) - -# # sumw where mask (poisson) else widths (gauss) -# _widths = jnp.where(self.mask, self.sumw, self.widths) - -# def _mod( -# value: Array, -# width: Array, -# idx: Array, -# effect: type[poisson] | type[gauss], -# ) -> Array: -# return effect(width).scale_factor( -# parameter=Parameter(value=value), -# sumw=sumw[idx], -# )[0] - -# _poisson_mod = partial(_mod, effect=poisson) -# _gauss_mod = partial(_mod, effect=gauss) - -# # apply -# return jnp.where( -# self.mask, -# jax.vmap(_poisson_mod)(values, _widths, idxs), -# jax.vmap(_gauss_mod)(values, _widths, idxs), -# ) - -# def __call__(self, sumw: Array) -> Array: -# # both gauss and poisson behave multiplicative -# op = operator.mul -# sf = self.scale_factor(sumw=sumw) -# return op(jnp.atleast_1d(sf), sumw) - - -# class autostaterrors(eqx.Module): -# class Mode(eqx.Enumeration): -# barlow_beeston_full = ( -# "Barlow-Beeston (full) approach: Poisson per process and bin" -# ) -# poisson_gauss = "Poisson (Gauss) per process and bin if sumw < (>) threshold" -# barlow_beeston_lite = "Barlow-Beeston (lite) approach" - -# sumw: dict[str, Array] -# sumw2: dict[str, Array] -# masks: dict[str, Array] -# threshold: float -# mode: str -# key_template: str = eqx.field(static=True) - -# def __init__( -# self, -# sumw: dict[str, Array], -# sumw2: dict[str, Array], -# threshold: float = 10.0, -# mode: str = Mode.barlow_beeston_lite, -# key_template: str = "__staterror_{process}__", -# ) -> None: -# self.sumw = sumw -# self.sumw2 = sumw2 -# self.masks = {p: _sumw < threshold for p, _sumw in sumw.items()} -# self.threshold = threshold -# self.mode = mode -# self.key_template = key_template - -# def __check_init__(self): -# if jax.tree_util.tree_structure(self.sumw) != jax.tree_util.tree_structure( -# self.sumw2 -# ): # type: ignore[operator] -# msg = ( -# "The structure of `sumw` and `sumw2` needs to be identical, got " -# f"`sumw`: {jax.tree_util.tree_structure(self.sumw)}) and " -# f"`sumw2`: {jax.tree_util.tree_structure(self.sumw2)})" -# ) -# raise ValueError(msg) -# if not self.threshold > 0.0: -# msg = f"Threshold must be >= 0.0, got: {self.threshold}" -# raise ValueError(msg) -# if not isinstance(self.mode, self.Mode): -# msg = f"Mode must be of type {self.Mode}, got: {self.mode}" -# raise ValueError(msg) - -# def prepare( -# self, -# ) -> tuple[dict[str, dict[str, Parameter]], dict[str, dict[str, eqx.Partial]]]: -# """ -# Helper to automatically create parameters used by `staterror` -# for the initialisation of a `evm.Model`. - -# *Caution*: This function is not compatible with JAX-transformations (e.g. `jax.jit`)! - -# Example: - -# .. code-block:: python - -# import jax.numpy as jnp -# import evermore as evm - -# sumw = { -# "signal": jnp.array([5, 20, 30]), -# "background": jnp.array([5, 20, 30]), -# } - -# sumw2 = { -# "signal": jnp.array([5, 20, 30]), -# "background": jnp.array([5, 20, 30]), -# } - - -# auto = evm.autostaterrors( -# sumw=sumw, -# sumw2=sumw2, -# threshold=10.0, -# mode=evm.autostaterrors.Mode.barlow_beeston_full, -# ) -# parameters, staterrors = auto.prepare() - -# # barlow-beeston-lite -# auto2 = evm.autostaterrors( -# sumw=sumw, -# sumw2=sumw2, -# threshold=10.0, -# mode=evm.autostaterrors.Mode.barlow_beeston_lite, -# ) -# parameters2, staterrors2 = auto2.prepare() - -# # materialize: -# process = "signal" -# pkey = auto.key_template.format(process=process) -# modify = staterrors[pkey](parameters[pkey]) -# modified_process = modify(sumw[process]) -# """ -# import equinox as eqx - -# parameters: dict[str, dict[str, Parameter]] = {} -# staterrors: dict[str, dict[str, eqx.Partial]] = {} - -# for process, _sumw in self.sumw.items(): -# key = self.key_template.format(process=process) -# process_parameters = parameters[key] = {} -# mask = self.masks[process] -# for i in range(len(_sumw)): -# pkey = f"{process}_{i}" -# if self.mode == self.Mode.barlow_beeston_lite and not mask[i]: -# # we merge all processes into one parameter -# # for the barlow-beeston-lite approach where -# # the bin content is above a certain threshold -# pkey = f"{i}" -# process_parameters[pkey] = Parameter(value=jnp.array(0.0)) -# # prepare staterror -# kwargs = { -# "sumw": _sumw, -# "sumw2": self.sumw2[process], -# "threshold": self.threshold, -# } -# if self.mode == self.Mode.barlow_beeston_full: -# kwargs["threshold"] = jnp.inf # inf -> always poisson -# elif self.mode == self.Mode.barlow_beeston_lite: -# kwargs["sumw"] = jnp.where( -# mask, -# _sumw, -# sum(jax.tree_util.tree_leaves(self.sumw)), -# ) -# kwargs["sumw2"] = jnp.where( -# mask, -# self.sumw2[process], -# sum(jax.tree_util.tree_leaves(self.sumw2)), -# ) -# staterrors[key] = eqx.Partial(staterror, **kwargs) -# return parameters, staterrors +Composable = Modifier | compose | where diff --git a/src/evermore/parameter.py b/src/evermore/parameter.py index 76fce2b..44dbd19 100644 --- a/src/evermore/parameter.py +++ b/src/evermore/parameter.py @@ -1,19 +1,22 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import equinox as eqx import jax.numpy as jnp -from jaxtyping import Array, ArrayLike, Float +import jax.tree_util as jtu +from jaxtyping import Array, ArrayLike, Float, PyTree -from evermore.pdf import HashablePDF +from evermore.custom_types import Sentinel, _NoValue +from evermore.pdf import PDF from evermore.util import as1darray if TYPE_CHECKING: - from evermore.modifier import modifier + from evermore.modifier import Modifier __all__ = [ "Parameter", + "staterrors", "auto_init", ] @@ -26,23 +29,38 @@ class Parameter(eqx.Module): value: Array = eqx.field(converter=as1darray) lower: Array = eqx.field(static=True, converter=as1darray) upper: Array = eqx.field(static=True, converter=as1darray) - constraints: set[HashablePDF] = eqx.field(static=True) + constraint: PDF | Sentinel = eqx.field(static=True) def __init__( self, value: ArrayLike = 0.0, lower: ArrayLike = -jnp.inf, upper: ArrayLike = jnp.inf, + constraint: PDF | Sentinel = _NoValue, ) -> None: self.value = as1darray(value) self.lower = as1darray(lower) self.upper = as1darray(upper) - self.constraints: set[HashablePDF] = set() - - def update(self, value: Array | Parameter) -> Parameter: - if isinstance(value, Parameter): - value = value.value - return eqx.tree_at(lambda t: t.value, self, value) + self.constraint = constraint + + def _set_constraint(self, constraint: PDF, overwrite: bool = False) -> PDF: + # Frozen dataclasses don't support setting attributes so we have to + # overload that operation here as they do in the dataclass implementation + assert isinstance(constraint, PDF) + + # If no constraint is set or overwriting is allowed, set it and return. + if self.constraint is _NoValue or overwrite: + object.__setattr__(self, "constraint", constraint) + return constraint + + # Check if new constraint is compatible by class only, otherwise complain. + # This is ok because we know that the constraints from evm.modifiers + # will always be compatible within the same class (underlying arrays are equal by construction). + # This significantly speeds up this check. + if self.constraint.__class__ is not constraint.__class__: + msg = f"Parameter constraint '{self.constraint}' is different than the constraint {constraint} to be added." + raise ValueError(msg) + return cast(PDF, self.constraint) @property def boundary_penalty(self) -> Array: @@ -53,30 +71,57 @@ def boundary_penalty(self) -> Array: ) # shorthands - def unconstrained(self) -> modifier: + def unconstrained(self) -> Modifier: import evermore as evm - return evm.modifier(parameter=self, effect=evm.effect.unconstrained()) + return evm.Modifier(parameter=self, effect=evm.effect.unconstrained()) - def gauss(self, width: Array) -> modifier: + def gauss(self, width: Array) -> Modifier: import evermore as evm - return evm.modifier(parameter=self, effect=evm.effect.gauss(width=width)) + return evm.Modifier(parameter=self, effect=evm.effect.gauss(width=width)) - def lnN(self, width: Float[Array, 2]) -> modifier: + def lnN(self, width: Float[Array, 2]) -> Modifier: import evermore as evm - return evm.modifier(parameter=self, effect=evm.effect.lnN(width=width)) + return evm.Modifier(parameter=self, effect=evm.effect.lnN(width=width)) - def poisson(self, lamb: Array) -> modifier: + def poisson(self, lamb: Array) -> Modifier: import evermore as evm - return evm.modifier(parameter=self, effect=evm.effect.poisson(lamb=lamb)) + return evm.Modifier(parameter=self, effect=evm.effect.poisson(lamb=lamb)) + + def shape(self, up: Array, down: Array) -> Modifier: + import evermore as evm + + return evm.Modifier(parameter=self, effect=evm.effect.shape(up=up, down=down)) + - def shape(self, up: Array, down: Array) -> modifier: +def staterrors(hists: PyTree[Array]) -> PyTree[Parameter]: + """ + Create staterror (barlow-beeston) parameters. + + Example: + + .. code-block:: python + + import jax.numpy as jnp import evermore as evm - return evm.modifier(parameter=self, effect=evm.effect.shape(up=up, down=down)) + hists = {"qcd": jnp.array([1, 2, 3]), "dy": jnp.array([4, 5, 6])} + + # bulk create staterrors + staterrors = evm.parameter.staterrors(hists=hists) + """ + + leaves = jtu.tree_leaves(hists) + # create parameters + return { + # per process and bin + "poisson": jtu.tree_map(lambda _: Parameter(value=0.0), hists), + # only per bin + "gauss": Parameter(value=jnp.zeros_like(leaves[0])), + } def auto_init(module: eqx.Module) -> eqx.Module: diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index f2d4542..3064cc7 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import equinox as eqx import jax @@ -9,10 +9,10 @@ from jaxtyping import Array, PRNGKeyArray if TYPE_CHECKING: - from evermore import Parameter + pass __all__ = [ - "HashablePDF", + "PDF", "Flat", "Gauss", "Poisson", @@ -23,11 +23,7 @@ def __dir__(): return __all__ -class HashablePDF(eqx.Module): - @abstractmethod - def __hash__(self) -> int: - ... - +class PDF(eqx.Module): @abstractmethod def logpdf(self, x: Array) -> Array: ... @@ -41,18 +37,11 @@ def cdf(self, x: Array) -> Array: ... @abstractmethod - def inv_cdf(self, x: Array) -> Array: - ... - - @abstractmethod - def sample(self, key: PRNGKeyArray, parameter: Parameter) -> Array: + def sample(self, key: PRNGKeyArray) -> Array: ... -class Flat(HashablePDF): - def __hash__(self): - return hash(self.__class__) - +class Flat(PDF): def logpdf(self, x: Array) -> Array: return jnp.array([0.0]) @@ -62,31 +51,18 @@ def pdf(self, x: Array) -> Array: def cdf(self, x: Array) -> Array: return jnp.array([1.0]) - def inv_cdf(self, x: Array) -> Array: - msg = "Flat distribution has no inverse CDF." - raise ValueError(msg) - - def sample(self, key: PRNGKeyArray, parameter: Parameter) -> Array: - return jax.random.uniform( - key, - parameter.value.shape, - # what should be the ranges? - # +/-jnp.inf leads to nans... - # minval=parameter.lower, - # maxval=parameter.upper, - ) - - -class Gauss(HashablePDF): - mean: float = eqx.field(static=True) - width: float = eqx.field(static=True) + def sample(self, key: PRNGKeyArray) -> Array: + # sample parameter from pdf + # what should be the ranges? + # +/-jnp.inf leads to nans... + # minval=??, + # maxval=??, + return jax.random.uniform(key) - def __init__(self, mean: float, width: float) -> None: - self.mean = mean - self.width = width - def __hash__(self): - return hash(self.__class__) ^ hash((self.mean, self.width)) +class Gauss(PDF): + mean: Array = eqx.field(static=True) + width: Array = eqx.field(static=True) def logpdf(self, x: Array) -> Array: logpdf_max = jax.scipy.stats.norm.logpdf( @@ -101,34 +77,14 @@ def pdf(self, x: Array) -> Array: def cdf(self, x: Array) -> Array: return jax.scipy.stats.norm.cdf(x, loc=self.mean, scale=self.width) - def inv_cdf(self, x: Array) -> Array: - return jax.scipy.stats.norm.ppf(x, loc=self.mean, scale=self.width) - - def sample(self, key: PRNGKeyArray, parameter: Parameter) -> Array: - return self.mean + self.width * jax.random.normal( - key, - shape=parameter.value.shape, - dtype=parameter.value.dtype, - ) + def sample(self, key: PRNGKeyArray) -> Array: + # sample parameter from pdf + return self.mean + self.width * jax.random.normal(key) -class Poisson(HashablePDF): +class Poisson(PDF): lamb: Array = eqx.field(static=True) - def __init__(self, lamb: Array) -> None: - self.lamb = lamb - - def __hash__(self): - return hash(self.__class__) - - def __eq__(self, other: Any): # type: ignore[override] - if not isinstance(other, Poisson): - return ValueError(f"Cannot compare Poisson with {type(other)}") - # We need to implement __eq__ explicitely because we have a non-hashable field (lamb). - # Implementing __eq__ is necessary for the `==` operator to work and to ensure that the - # Poisson distribution is correctly added to a python set. - return jnp.all(self.lamb == other.lamb) - def logpdf(self, x: Array) -> Array: logpdf_max = jax.scipy.stats.poisson.logpmf(self.lamb, mu=self.lamb) unnormalized = jax.scipy.stats.poisson.logpmf((x + 1) * self.lamb, mu=self.lamb) @@ -140,26 +96,10 @@ def pdf(self, x: Array) -> Array: def cdf(self, x: Array) -> Array: return jax.scipy.stats.poisson.cdf((x + 1) * self.lamb, mu=self.lamb) - def inv_cdf(self, x: Array) -> Array: - # see: https://num.pyro.ai/en/stable/tutorials/truncated_distributions.html?highlight=poisson%20inverse#5.3-Example:-Left-truncated-Poisson - def cond_fn(val): - n, cdf = val - return jnp.any(cdf < x) - - def body_fn(val): - n, cdf = val - n_new = jnp.where(cdf < x, n + 1, n) - return n_new, self.cdf(n_new) - - start = jnp.zeros_like(x) - cdf_start = self.cdf(start) - n, _ = jax.lax.while_loop(cond_fn, body_fn, (start, cdf_start)) - return n.astype(jnp.result_type(int)) - - def sample(self, key: PRNGKeyArray) -> Array: # type: ignore[override] - return jax.random.poisson( - key, - self.lamb, - shape=self.lamb.shape, - dtype=self.lamb.dtype, - ) + def sample(self, key: PRNGKeyArray) -> Array: + # sample parameter from pdf + # some problems with this: + # - this samples only integers, do we want that? + # - this breaks for 0 in self.lamb + # - if jax.random.poisson(key, self.lamb) == 0 then what do we know about the parameter? + return (jax.random.poisson(key, self.lamb) / self.lamb) - 1 diff --git a/src/evermore/sample.py b/src/evermore/sample.py index 5ea6632..5487406 100644 --- a/src/evermore/sample.py +++ b/src/evermore/sample.py @@ -1,9 +1,12 @@ from collections.abc import Callable +from typing import cast import equinox as eqx import jax from jaxtyping import Array, PRNGKeyArray, PyTree +from evermore.custom_types import _NoValue +from evermore.pdf import PDF from evermore.util import is_parameter @@ -19,16 +22,15 @@ def toy_module(module: eqx.Module, key: PRNGKeyArray) -> PyTree[Callable]: keys_tree = jax.tree_util.tree_unflatten(params_structure, keys) def _sample(param: Parameter, key: Parameter) -> Array: - if not param.constraints: - msg = f"Parameter {param} has no constraint pdf, can't sample from it. Maybe you need to call the model once to populate all constraints?" + if param.constraint is _NoValue: + msg = f"Parameter {param} has no constraint pdf, can't sample from it." raise RuntimeError(msg) - if len(param.constraints) > 1: - msg = f"More than one constraint per parameter is not allowed. Got: {param.constraints}" - raise ValueError(msg) - pdf = next(iter(param.constraints)) + + pdf = param.constraint + pdf = cast(PDF, pdf) # sample new value from the constraint pdf - sampled_param_value = pdf.sample(key.value, param) + sampled_param_value = pdf.sample(key.value) # replace the sampled parameter value and return new parameter return eqx.tree_at(lambda p: p.value, param, sampled_param_value) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 987e66c..14f8a59 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -4,57 +4,49 @@ import pytest import evermore as evm +from evermore.custom_types import _NoValue from evermore.pdf import Flat, Gauss, Poisson def test_parameter(): p = evm.Parameter(value=jnp.array(1.0), lower=jnp.array(0.0), upper=jnp.array(2.0)) - assert p.value == 1.0 - assert p.update(jnp.array(2.0)).value == 2.0 assert p.lower == 0.0 assert p.upper == 2.0 - assert p.boundary_penalty == 0.0 - assert p.update(jnp.array(3.0)).boundary_penalty == jnp.inf + assert p.constraint is _NoValue def test_unconstrained(): p = evm.Parameter(value=jnp.array(1.0)) u = evm.effect.unconstrained() - assert u.constraint == Flat() + assert u.constraint(p) == Flat() assert u.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) - assert u.scale_factor(p.update(jnp.array(2.0)), jnp.array(1.0)) == pytest.approx( - 2.0 - ) def test_gauss(): p = evm.Parameter(value=jnp.array(0.0)) g = evm.effect.gauss(width=jnp.array(1.0)) - assert g.constraint == Gauss(mean=0.0, width=1.0) + assert g.constraint(p) == Gauss(mean=jnp.array(0.0), width=jnp.array(1.0)) assert g.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) - assert g.scale_factor(p.update(jnp.array(2.0)), jnp.array(1.0)) == pytest.approx( - 3.0 - ) def test_lnN(): p = evm.Parameter(value=jnp.array(0.0)) ln = evm.effect.lnN(width=jnp.array([0.9, 1.1])) - assert ln.constraint == Gauss(mean=0.0, width=1.0) + assert ln.constraint(p) == Gauss(mean=jnp.array(0.0), width=jnp.array(1.0)) assert ln.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) # assert ln.scale_factor(p.update(jnp.array(2.0)), jnp.array(1.0)) == pytest.approx(1.1) def test_poisson(): - # p = evm.Parameter(value=jnp.array(0.0)) + p = evm.Parameter(value=jnp.array(0.0)) po = evm.effect.poisson(lamb=jnp.array(10)) - assert po.constraint == Poisson(lamb=jnp.array(10)) + assert po.constraint(p) == Poisson(lamb=jnp.array(10)) # assert po.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) # FIXME # assert po.scale_factor(p.update(jnp.array(2.0)), jnp.array(1.0)) == pytest.approx(1.1) # FIXME @@ -68,26 +60,26 @@ def test_modifier(): norm = evm.Parameter(value=jnp.array(0.0)) # unconstrained effect - m_unconstrained = evm.modifier(parameter=mu, effect=evm.effect.unconstrained()) + m_unconstrained = evm.Modifier(parameter=mu, effect=evm.effect.unconstrained()) assert m_unconstrained(jnp.array([10])) == pytest.approx(11) # gauss effect - m_gauss = evm.modifier(parameter=norm, effect=evm.effect.gauss(jnp.array(0.1))) + m_gauss = evm.Modifier(parameter=norm, effect=evm.effect.gauss(jnp.array(0.1))) assert m_gauss(jnp.array([10])) == pytest.approx(10) # lnN effect - m_lnN = evm.modifier( + m_lnN = evm.Modifier( parameter=norm, effect=evm.effect.lnN(width=jnp.array([0.9, 1.1])) ) assert m_lnN(jnp.array([10])) == pytest.approx(10) # poisson effect # FIXME - # m_poisson = modifier(name="norm", parameter=norm, effect=poisson(jnp.array(10))) + # m_poisson = Modifier(name="norm", parameter=norm, effect=poisson(jnp.array(10))) # assert m_poisson(jnp.array(10)) == pytest.approx(10) # shape effect # FIXME # effect = shape(up=jnp.array(12), down=jnp.array(8)) - # m_shape = modifier(name="norm", parameter=norm, effect=effect) + # m_shape = Modifier(name="norm", parameter=norm, effect=effect) # assert m_shape(jnp.array(10)) == pytest.approx(10) @@ -96,12 +88,12 @@ def test_compose(): norm = evm.Parameter(value=jnp.array(0.0)) # unconstrained effect - m_unconstrained = evm.modifier(parameter=mu, effect=evm.effect.unconstrained()) + m_unconstrained = evm.Modifier(parameter=mu, effect=evm.effect.unconstrained()) # gauss effect - m_gauss = evm.modifier(parameter=norm, effect=evm.effect.gauss(jnp.array(0.1))) + m_gauss = evm.Modifier(parameter=norm, effect=evm.effect.gauss(jnp.array(0.1))) # compose - m = evm.compose(m_unconstrained, m_gauss) + m = evm.modifier.compose(m_unconstrained, m_gauss) assert len(m) == 2 assert m(jnp.array([10])) == pytest.approx(11) diff --git a/tests/test_pdf.py b/tests/test_pdf.py index 1e83ca8..da55582 100644 --- a/tests/test_pdf.py +++ b/tests/test_pdf.py @@ -19,13 +19,11 @@ def test_flat(): def test_gauss(): - pdf = Gauss(mean=0.0, width=1.0) + pdf = Gauss(mean=jnp.array(0.0), width=jnp.array(1.0)) assert pdf.pdf(jnp.array(0.0)) == pytest.approx(1.0 / jnp.sqrt(2 * jnp.pi)) assert pdf.logpdf(jnp.array(0.0)) == pytest.approx(0.0) assert pdf.cdf(jnp.array(0.0)) == pytest.approx(0.5) - assert pdf.inv_cdf(jnp.array(0.5)) == pytest.approx(0.0) - assert pdf.inv_cdf(pdf.cdf(jnp.array(0.0))) == pytest.approx(0.0) def test_poisson(): @@ -34,11 +32,11 @@ def test_poisson(): assert pdf.pdf(jnp.array(0)) == pytest.approx(0.12510978) assert pdf.logpdf(jnp.array(-0.5)) == pytest.approx(-1.196003) assert pdf.cdf(jnp.array(0)) == pytest.approx(0.5830412) - # assert pdf.inv_cdf(jnp.array(0.5830412)) == pytest.approx(10) - # assert pdf.inv_cdf(pdf.cdf(jnp.array(10))) == pytest.approx(10) def test_hashable(): assert hash(Flat()) == hash(Flat()) - assert hash(Gauss(mean=0.0, width=1.0)) == hash(Gauss(mean=0.0, width=1.0)) + assert hash(Gauss(mean=jnp.array(0.0), width=jnp.array(1.0))) == hash( + Gauss(mean=jnp.array(0.0), width=jnp.array(1.0)) + ) assert hash(Poisson(lamb=jnp.array(10))) == hash(Poisson(lamb=jnp.array(10))) From ab67191279dd055fb8dd32a1ab448d12b5bdf0fc Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Fri, 8 Mar 2024 15:55:14 +0100 Subject: [PATCH 12/22] disable example testing and code cov upload --- .github/workflows/ci.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 45adb7e..8df28b5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,9 +51,9 @@ jobs: python -m pytest -ra --cov --cov-report=xml --cov-report=term --durations=20 - - name: Upload coverage report - uses: codecov/codecov-action@v4.0.2 + # - name: Upload coverage report + # uses: codecov/codecov-action@v4.0.2 - - name: Test examples - run: >- - for f in examples/*.py; do echo "run $f" && python "$f"; done + # - name: Test examples + # run: >- + # for f in examples/*.py; do echo "run $f" && python "$f"; done From e5f12a41ea386b0788af52c99cfe172cd34d6cc5 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Sat, 9 Mar 2024 11:23:14 +0100 Subject: [PATCH 13/22] streamline sf calculations; now it's easier to add custom modifiers --- pixi.lock | 140 ++++++++++++++++++++++++++++++++--- pixi.toml | 1 + src/evermore/custom_types.py | 37 ++++++++- src/evermore/effect.py | 58 +++++++-------- src/evermore/modifier.py | 136 +++++++++++----------------------- src/evermore/util.py | 15 ++-- tests/test_parameter.py | 18 ++++- 7 files changed, 262 insertions(+), 143 deletions(-) diff --git a/pixi.lock b/pixi.lock index 2686974..24b3da6 100644 --- a/pixi.lock +++ b/pixi.lock @@ -148,6 +148,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/decorator-5.1.1-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/equinox-0.11.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/executing-2.0.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2 @@ -163,8 +164,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/importlib_metadata-7.0.1-hd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ipython-8.22.1-pyh707e725_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jax-0.4.23-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/jaxtyping-0.2.28-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jedi-0.19.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jinja2-3.1.3-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/lineax-0.0.4-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/matplotlib-inline-0.1.6-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/mdit-py-plugins-0.4.0-pyhd8ed1ab_0.conda @@ -175,6 +178,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/myst-parser-2.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/opt-einsum-3.3.0-hd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/opt_einsum-3.3.0-pyhc1e730c_2.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/optimistix-0.0.6-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/parso-0.8.3-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pexpect-4.9.0-pyhd8ed1ab_0.conda @@ -204,7 +208,9 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/tomli-2.0.1-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/traitlets-5.14.1-pyhd8ed1ab_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.9.0-pyha770c72_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/typeguard-2.13.3-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.10.0-hd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.10.0-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2024a-h0c530f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/uhi-0.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/urllib3-2.2.1-pyhd8ed1ab_0.conda @@ -220,6 +226,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/decorator-5.1.1-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/equinox-0.11.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/executing-2.0.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/idna-3.6-pyhd8ed1ab_0.conda @@ -229,8 +236,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/importlib_metadata-7.0.1-hd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ipython-8.22.1-pyh707e725_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jax-0.4.23-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/jaxtyping-0.2.28-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jedi-0.19.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jinja2-3.1.3-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/lineax-0.0.4-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/matplotlib-inline-0.1.6-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/mdit-py-plugins-0.4.0-pyhd8ed1ab_0.conda @@ -241,6 +250,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/myst-parser-2.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/opt-einsum-3.3.0-hd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/opt_einsum-3.3.0-pyhc1e730c_2.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/optimistix-0.0.6-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/parso-0.8.3-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pexpect-4.9.0-pyhd8ed1ab_0.conda @@ -267,7 +277,9 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/stack_data-0.6.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/traitlets-5.14.1-pyhd8ed1ab_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.9.0-pyha770c72_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/typeguard-2.13.3-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.10.0-hd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.10.0-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2024a-h0c530f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/uhi-0.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/urllib3-2.2.1-pyhd8ed1ab_0.conda @@ -344,6 +356,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/decorator-5.1.1-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/equinox-0.11.3-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/executing-2.0.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/idna-3.6-pyhd8ed1ab_0.conda @@ -353,8 +366,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/importlib_metadata-7.0.1-hd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ipython-8.22.1-pyh707e725_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jax-0.4.23-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/jaxtyping-0.2.28-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jedi-0.19.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jinja2-3.1.3-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/lineax-0.0.4-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/matplotlib-inline-0.1.6-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/mdit-py-plugins-0.4.0-pyhd8ed1ab_0.conda @@ -365,6 +380,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/myst-parser-2.0.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/opt-einsum-3.3.0-hd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/opt_einsum-3.3.0-pyhc1e730c_2.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/optimistix-0.0.6-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/packaging-23.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/parso-0.8.3-pyhd8ed1ab_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/pexpect-4.9.0-pyhd8ed1ab_0.conda @@ -391,7 +407,9 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/stack_data-0.6.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/traitlets-5.14.1-pyhd8ed1ab_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.9.0-pyha770c72_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/typeguard-2.13.3-pyhd8ed1ab_0.tar.bz2 + - conda: https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.10.0-hd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.10.0-pyha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2024a-h0c530f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/uhi-0.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/urllib3-2.2.1-pyhd8ed1ab_0.conda @@ -1060,6 +1078,24 @@ packages: license: CC-PDDC AND BSD-3-Clause AND BSD-2-Clause AND ZPL-2.1 size: 919457 timestamp: 1701883162608 +- kind: conda + name: equinox + version: 0.11.3 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/equinox-0.11.3-pyhd8ed1ab_0.conda + sha256: bdd6efc3f24ae506de67b9b0e33e470423420784468bd5a5c99b24a337dbf9c8 + md5: dc0d62716985faa6f2733c816d9bb31f + depends: + - jax >=0.4.13 + - jaxtyping >=0.2.20 + - python >=3.9 + - typing-extensions >=4.5.0 + license: Apache-2.0 + license_family: APACHE + size: 123228 + timestamp: 1705016222774 - kind: conda name: exceptiongroup version: 1.2.0 @@ -1667,6 +1703,23 @@ packages: license_family: APACHE size: 54820454 timestamp: 1707874036386 +- kind: conda + name: jaxtyping + version: 0.2.28 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/jaxtyping-0.2.28-pyhd8ed1ab_0.conda + sha256: 6a5f5ce1fecee35e7fc6596943558f6a8e4f98131b4a90a9a9965692a57eca5a + md5: cb58cd9674126626a66e2b74a1859158 + depends: + - numpy >=1.20.0 + - python >=3.9 + - typeguard >=2.13.3,<3 + - typing-extensions >=3.7.4.1 + license: MIT + size: 35332 + timestamp: 1709843563539 - kind: conda name: jedi version: 0.19.1 @@ -3467,6 +3520,25 @@ packages: license_family: Other size: 61588 timestamp: 1686575217516 +- kind: conda + name: lineax + version: 0.0.4 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/lineax-0.0.4-pyhd8ed1ab_0.conda + sha256: 3aa0c5dd8a6dc92968d2d6520de140fbec23d18a89477168c6f826a9f30e288b + md5: 4fc3c2e0dd97b0f6a0bf21c131ca1035 + depends: + - equinox >=0.11.0 + - jax >=0.4.13 + - jaxtyping >=0.2.20 + - python >=3.9,<4.dev0 + - typing-extensions >=4.5.0 + license: Apache-2.0 + license_family: APACHE + size: 46961 + timestamp: 1707348622614 - kind: conda name: llvm-openmp version: 17.0.6 @@ -4209,6 +4281,26 @@ packages: license_family: MIT size: 58004 timestamp: 1696449058916 +- kind: conda + name: optimistix + version: 0.0.6 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/optimistix-0.0.6-pyhd8ed1ab_0.conda + sha256: 8a25a6c9ae84b6f28211044cad6cca1e4eb55386c644dd58f9911b093dcc7efb + md5: 34c18187c3cc9fb17a0cf6f3ee9523d1 + depends: + - equinox >=0.11.1 + - jax >=0.4.18 + - jaxtyping >=0.2.23 + - lineax >=0.0.4 + - python >=3.9 + - typing-extensions >=4.5.0 + license: Apache-2.0 + license_family: APACHE + size: 54985 + timestamp: 1707448935801 - kind: conda name: packaging version: '23.2' @@ -5402,21 +5494,51 @@ packages: license_family: BSD size: 110329 timestamp: 1704213177224 +- kind: conda + name: typeguard + version: 2.13.3 + build: pyhd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/typeguard-2.13.3-pyhd8ed1ab_0.tar.bz2 + sha256: 5e6a3b92c4f99b7b51c646669ad9dbffd307a788fda0001dc82476de8e67ad67 + md5: af104e581c40c72813864454962e1795 + depends: + - python >=3.4 + license: MIT + license_family: MIT + size: 19655 + timestamp: 1658932205516 +- kind: conda + name: typing-extensions + version: 4.10.0 + build: hd8ed1ab_0 + subdir: noarch + noarch: python + url: https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.10.0-hd8ed1ab_0.conda + sha256: 0698fe2c4e555fb44c27c60f7a21fa0eea7f5bf8186ad109543c5b056e27f96a + md5: 091683b9150d2ebaa62fd7e2c86433da + depends: + - typing_extensions 4.10.0 pyha770c72_0 + license: PSF-2.0 + license_family: PSF + size: 10181 + timestamp: 1708904805365 - kind: conda name: typing_extensions - version: 4.9.0 + version: 4.10.0 build: pyha770c72_0 subdir: noarch noarch: python - url: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.9.0-pyha770c72_0.conda - sha256: f3c5be8673bfd905c4665efcb27fa50192f24f84fa8eff2f19cba5d09753d905 - md5: a92a6440c3fe7052d63244f3aba2a4a7 + url: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.10.0-pyha770c72_0.conda + sha256: 4be24d557897b2f6609f5d5f7c437833c62f4d4a96581e39530067e96a2d0451 + md5: 16ae769069b380646c47142d719ef466 depends: - python >=3.8 license: PSF-2.0 license_family: PSF - size: 36058 - timestamp: 1702176292645 + size: 37018 + timestamp: 1708904796013 - kind: conda name: tzdata version: 2024a diff --git a/pixi.toml b/pixi.toml index f648bba..175feb0 100644 --- a/pixi.toml +++ b/pixi.toml @@ -19,6 +19,7 @@ myst-parser = "*" matplotlib = "*" mplhep = "*" imageio = "*" +optimistix = "*" [host-dependencies] pip = "*" diff --git a/src/evermore/custom_types.py b/src/evermore/custom_types.py index 737ec9d..4c2c27c 100644 --- a/src/evermore/custom_types.py +++ b/src/evermore/custom_types.py @@ -1,9 +1,16 @@ +from __future__ import annotations + from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +from jaxtyping import Array + +if TYPE_CHECKING: + from evermore.modifier import compose -import jaxtyping -AddOrMul = Callable[[jaxtyping.ArrayLike, jaxtyping.ArrayLike], jaxtyping.Array] +AddOrMul = Callable[[Array, Array], Array] +AddOrMulSFs = dict[AddOrMul, Array] class Sentinel: @@ -19,3 +26,27 @@ def __repr__(self) -> str: _NoValue: Any = Sentinel("") + + +@runtime_checkable +class ModifierLike(Protocol): + def scale_factor(self, sumw: Array) -> AddOrMulSFs: + """ + Always return a dictionary of scale factors for the sumw array. + Dictionary has to look as follows: + + .. code-block:: python + + import operator + from jaxtyping import Array + + + {operator.mul: Array, operator.add: Array} + """ + ... + + def __call__(self, sumw: Array) -> Array: + ... + + def __matmul__(self, other: ModifierLike) -> compose: + ... diff --git a/src/evermore/effect.py b/src/evermore/effect.py index 6256a9d..8d348b4 100644 --- a/src/evermore/effect.py +++ b/src/evermore/effect.py @@ -1,20 +1,20 @@ import abc import operator -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING import equinox as eqx import jax.numpy as jnp from jaxtyping import Array, Float -from evermore.custom_types import AddOrMul +from evermore.custom_types import AddOrMulSFs from evermore.parameter import Parameter from evermore.pdf import PDF, Flat, Gauss, Poisson -from evermore.util import as1darray +from evermore.util import as1darray, initSF if TYPE_CHECKING: - from typing import ClassVar as AbstractClassVar + pass else: - from equinox import AbstractClassVar + pass __all__ = [ @@ -32,25 +32,23 @@ def __dir__(): class Effect(eqx.Module): - apply_op: AbstractClassVar[AddOrMul] - @abc.abstractmethod def constraint(self, parameter: Parameter) -> PDF: ... @abc.abstractmethod - def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: ... class unconstrained(Effect): - apply_op: ClassVar[AddOrMul] = operator.mul - def constraint(self, parameter: Parameter) -> PDF: return Flat() - def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: - return parameter.value + def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: + sf = initSF(shape=parameter.value.shape) + sf[operator.mul] = parameter.value + return sf DEFAULT_EFFECT = unconstrained() @@ -59,8 +57,6 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: class gauss(Effect): width: Array = eqx.field(static=True, converter=as1darray) - apply_op: ClassVar[AddOrMul] = operator.mul - def __init__(self, width: Array) -> None: self.width = width @@ -69,7 +65,7 @@ def constraint(self, parameter: Parameter) -> PDF: mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) ) - def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: """ Implementation with (inverse) CDFs is defined as follows: @@ -87,15 +83,15 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: return (parameter.value * self.width) + 1 """ - return (parameter.value * self.width) + 1 + sf = initSF(shape=parameter.value.shape) + sf[operator.mul] = (parameter.value * self.width) + 1 + return sf class shape(Effect): up: Array = eqx.field(converter=as1darray) down: Array = eqx.field(converter=as1darray) - apply_op: ClassVar[AddOrMul] = operator.add - def __init__( self, up: Array, @@ -128,9 +124,11 @@ def constraint(self, parameter: Parameter) -> PDF: mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) ) - def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: - sf = parameter.value - return self.vshift(sf=sf, sumw=sumw) + def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: + p = parameter.value + sf = initSF(shape=p.shape) + sf[operator.add] = self.vshift(sf=p, sumw=sumw) + return sf # shift = self.vshift(sf=sf, sumw=sumw) # # handle zeros, see: https://github.com/google/jax/issues/5039 # x = jnp.where(sumw == 0.0, 1.0, sumw) @@ -140,8 +138,6 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: class lnN(Effect): width: Float[Array, "2"] = eqx.field(static=True) - apply_op: ClassVar[AddOrMul] = operator.mul - def __init__( self, width: Float[Array, "2"], # given as (down, up) @@ -170,7 +166,7 @@ def constraint(self, parameter: Parameter) -> PDF: mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) ) - def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: + def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: """ Implementation with (inverse) CDFs is defined as follows: @@ -188,14 +184,16 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: return jnp.exp(parameter.value * self.interpolate(parameter=parameter)) """ - return jnp.exp(parameter.value * self.interpolate(parameter=parameter)) + sf = initSF(shape=parameter.value.shape) + sf[operator.mul] = jnp.exp( + parameter.value * self.interpolate(parameter=parameter) + ) + return sf class poisson(Effect): lamb: Array = eqx.field(static=True, converter=as1darray) - apply_op: ClassVar[AddOrMul] = operator.mul - def __init__(self, lamb: Array) -> None: self.lamb = lamb @@ -203,5 +201,7 @@ def constraint(self, parameter: Parameter) -> PDF: assert parameter.value.shape == self.lamb.shape return Poisson(lamb=self.lamb) - def scale_factor(self, parameter: Parameter, sumw: Array) -> Array: - return parameter.value + 1 + def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: + sf = initSF(shape=parameter.value.shape) + sf[operator.add] = parameter.value + 1 + return sf diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index f8eee16..0e01476 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -2,18 +2,17 @@ import operator from functools import reduce -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING import equinox as eqx import jax import jax.numpy as jnp from jaxtyping import Array -from evermore.custom_types import AddOrMul -from evermore.effect import ( - DEFAULT_EFFECT, -) +from evermore.custom_types import AddOrMul, AddOrMulSFs, ModifierLike +from evermore.effect import DEFAULT_EFFECT from evermore.parameter import Parameter +from evermore.util import initSF if TYPE_CHECKING: from evermore.effect import Effect @@ -29,7 +28,15 @@ def __dir__(): return __all__ -class Modifier(eqx.Module): +class ApplyFn(eqx.Module): + @jax.named_scope("evm.modifier.ApplyFn") + def __call__(self: ModifierLike, sumw: Array) -> Array: + sf = self.scale_factor(sumw=sumw) + # apply + return sf[operator.mul] * (sumw + sf[operator.add]) + + +class Modifier(ApplyFn): """ Create a new modifier for a given parameter and penalty. @@ -82,21 +89,14 @@ def __init__(self, parameter: Parameter, effect: Effect = DEFAULT_EFFECT) -> Non constraint = self.effect.constraint(parameter=self.parameter) self.parameter._set_constraint(constraint, overwrite=False) - def scale_factor(self, sumw: Array) -> Array: + def scale_factor(self, sumw: Array) -> AddOrMulSFs: return self.effect.scale_factor(parameter=self.parameter, sumw=sumw) - @jax.named_scope("evm.modifier") - def __call__(self, sumw: Array) -> Array: - op = self.effect.apply_op - shift = jnp.atleast_1d(self.scale_factor(sumw=sumw)) - shift = jnp.broadcast_to(shift, sumw.shape) - return op(shift, sumw) # type: ignore[call-arg] - - def __matmul__(self, other: Composable) -> compose: + def __matmul__(self, other: ModifierLike) -> compose: return compose(self, other) -class where(eqx.Module): +class where(ApplyFn): """ Combine two modifiers based on a condition. @@ -125,29 +125,21 @@ class where(eqx.Module): modifier_true: Modifier modifier_false: Modifier - def scale_factor(self, sumw: Array) -> Array: - return jnp.where( - self.condition, - self.modifier_true.scale_factor(sumw), - self.modifier_false.scale_factor(sumw), - ) + def scale_factor(self, sumw: Array) -> AddOrMulSFs: + sf = initSF(shape=sumw.shape) - @jax.named_scope("evm.where") - def __call__(self, sumw: Array) -> Array: - op_true = self.modifier_true.effect.apply_op - op_false = self.modifier_false.effect.apply_op - sf = self.scale_factor(sumw=sumw) - return jnp.where( - self.condition, - op_true(jnp.atleast_1d(sf), sumw), # type: ignore[call-arg] - op_false(jnp.atleast_1d(sf), sumw), # type: ignore[call-arg] - ) + true_sf = self.modifier_true.scale_factor(sumw) + false_sf = self.modifier_false.scale_factor(sumw) + + for op in operator.mul, operator.add: + sf.update(jnp.where(self.condition, true_sf[op], false_sf[op])) + return sf - def __matmul__(self, other: Composable) -> compose: + def __matmul__(self, other: ModifierLike) -> compose: return compose(self, other) -class compose(eqx.Module): +class compose(ApplyFn): """ Composition of multiple modifiers, i.e.: `(f ∘ g ∘ h)(hist) = f(hist) * g(hist) * h(hist)` It behaves like a single modifier, but it is composed of multiple modifiers; it can be arbitrarly nested. @@ -194,9 +186,9 @@ class compose(eqx.Module): eqx.filter_jit(composition)(hist) """ - modifiers: list[Composable] + modifiers: list[ModifierLike] - def __init__(self, *modifiers: Composable) -> None: + def __init__(self, *modifiers: ModifierLike) -> None: self.modifiers = list(modifiers) # unroll nested compositions _modifiers = [] @@ -204,7 +196,7 @@ def __init__(self, *modifiers: Composable) -> None: if isinstance(mod, compose): _modifiers.extend(mod.modifiers) else: - assert isinstance(mod, Modifier | where) + assert isinstance(mod, ModifierLike) _modifiers.append(mod) # by now all modifiers are either modifier or staterror self.modifiers = _modifiers @@ -212,61 +204,23 @@ def __init__(self, *modifiers: Composable) -> None: def __len__(self) -> int: return len(self.modifiers) - @jax.named_scope("evm.compose") - def __call__(self, sumw: Array) -> Array: - def _prep_shift(modifier: Modifier | where, sumw: Array) -> Array: - shift = modifier.scale_factor(sumw=sumw) - shift = jnp.atleast_1d(shift) - return jnp.broadcast_to(shift, sumw.shape) - + def scale_factor(self, sumw: Array) -> AddOrMulSFs: # collect all multiplicative and additive shifts - shifts: dict[AddOrMul, list] = {operator.mul: [], operator.add: []} + sfs: dict[AddOrMul, list] = {operator.add: [], operator.mul: []} for m in range(len(self)): mod = self.modifiers[m] - # cast to modifier | staterror, we know it is one of them - # because we unrolled nested compositions in __init__ - mod = cast(Modifier | where, mod) - sf = _prep_shift(mod, sumw) - if isinstance(mod, Modifier): - if mod.effect.apply_op is operator.mul: - shifts[operator.mul].append(sf) - elif mod.effect.apply_op is operator.add: - shifts[operator.add].append(sf) - else: - msg = f"Unsupported apply_op {mod.effect.apply_op} for Modifier {mod}. Only multiplicative and additive effects are supported." - raise ValueError(msg) - elif isinstance(mod, where): - op_true = mod.modifier_true.effect.apply_op - op_false = mod.modifier_false.effect.apply_op - # if both modifiers are multiplicative: - if op_true is operator.mul and op_false is operator.mul: - shifts[operator.mul].append(sf) - # if both modifiers are additive: - elif op_true is operator.add and op_false is operator.add: - shifts[operator.add].append(sf) - # if one is multiplicative and the other is additive: - elif op_true is operator.mul and op_false is operator.add: - _mult_sf = jnp.where(mod.condition, sf, 1.0) - _add_sf = jnp.where(mod.condition, sf, 0.0) - shifts[operator.mul].append(_mult_sf) - shifts[operator.add].append(_add_sf) - elif op_true is operator.add and op_false is operator.mul: - _mult_sf = jnp.where(mod.condition, 1.0, sf) - _add_sf = jnp.where(mod.condition, 0.0, sf) - shifts[operator.mul].append(_mult_sf) - shifts[operator.add].append(_add_sf) - else: - msg = f"Unsupported apply_op {op_true} and {op_false} for 'where' Modifier {mod}. Only multiplicative and additive effects are supported." - raise ValueError(msg) - # calculate the product with for operator.mul - _mult_fact = reduce(operator.mul, shifts[operator.mul], 1.0) - # calculate the sum for operator.add - _add_shift = reduce(operator.add, shifts[operator.add], 0.0) - # apply - return _mult_fact * (sumw + _add_shift) - - def __matmul__(self, other: Composable) -> compose: + _sf = mod.scale_factor(sumw) + for op in operator.add, operator.mul: + sfs[op].append(_sf[op]) + + sf = initSF(shape=sumw.shape) + # calculate the product with for operator.mul and operator.add + for op, init_val in ( + (operator.mul, jnp.ones_like(sumw)), + (operator.add, jnp.zeros_like(sumw)), + ): + sf[op] = reduce(op, sfs[op], init_val) + return sf + + def __matmul__(self, other: ModifierLike) -> compose: return compose(self, other) - - -Composable = Modifier | compose | where diff --git a/src/evermore/util.py b/src/evermore/util.py index db70722..b151d70 100644 --- a/src/evermore/util.py +++ b/src/evermore/util.py @@ -3,9 +3,7 @@ import operator from collections.abc import Callable from functools import partial -from typing import ( - Any, -) +from typing import Any import equinox as eqx import jax @@ -13,7 +11,10 @@ import jax.tree_util as jtu from jaxtyping import Array, ArrayLike, PyTree +from evermore.custom_types import AddOrMulSFs + __all__ = [ + "initSF", "is_parameter", "sum_leaves", "as1darray", @@ -26,16 +27,16 @@ def __dir__(): return __all__ +def initSF(shape: tuple) -> AddOrMulSFs: + return {operator.add: jnp.zeros(shape), operator.mul: jnp.ones(shape)} + + def is_parameter(leaf: Any) -> bool: from evermore import Parameter return isinstance(leaf, Parameter) -K = str -V = Any - - def _filtered_module_map( module: eqx.Module, fun: Callable, diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 14f8a59..277ecc0 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,5 +1,7 @@ from __future__ import annotations +import operator + import jax.numpy as jnp import pytest @@ -22,7 +24,10 @@ def test_unconstrained(): u = evm.effect.unconstrained() assert u.constraint(p) == Flat() - assert u.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) + assert u.scale_factor(p, jnp.array(1.0)) == { + operator.mul: jnp.array([1.0]), + operator.add: jnp.array([0.0]), + } def test_gauss(): @@ -30,7 +35,10 @@ def test_gauss(): g = evm.effect.gauss(width=jnp.array(1.0)) assert g.constraint(p) == Gauss(mean=jnp.array(0.0), width=jnp.array(1.0)) - assert g.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) + assert g.scale_factor(p, jnp.array(1.0)) == { + operator.mul: jnp.array([1.0]), + operator.add: jnp.array([0.0]), + } def test_lnN(): @@ -38,8 +46,10 @@ def test_lnN(): ln = evm.effect.lnN(width=jnp.array([0.9, 1.1])) assert ln.constraint(p) == Gauss(mean=jnp.array(0.0), width=jnp.array(1.0)) - assert ln.scale_factor(p, jnp.array(1.0)) == pytest.approx(1.0) - # assert ln.scale_factor(p.update(jnp.array(2.0)), jnp.array(1.0)) == pytest.approx(1.1) + assert ln.scale_factor(p, jnp.array(1.0)) == { + operator.mul: jnp.array([1.0]), + operator.add: jnp.array([0.0]), + } def test_poisson(): From 7413bcf73eec929894659ee2d9254e0ce0ddaa7a Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Sat, 9 Mar 2024 12:57:43 +0100 Subject: [PATCH 14/22] add ModifierBase for custom modifiers and use jax compatible pytree structure for scale factors --- src/evermore/__init__.py | 3 +- src/evermore/custom_types.py | 30 +++++----- src/evermore/effect.py | 46 +++++++--------- src/evermore/modifier.py | 104 ++++++++++++++++++++++++----------- src/evermore/util.py | 7 --- tests/test_parameter.py | 25 ++++----- 6 files changed, 117 insertions(+), 98 deletions(-) diff --git a/src/evermore/__init__.py b/src/evermore/__init__.py index 9b18ef0..cf9f0ca 100644 --- a/src/evermore/__init__.py +++ b/src/evermore/__init__.py @@ -27,6 +27,7 @@ # explicitely expose some classes "Parameter", "Modifier", + "ModifierBase", ] @@ -43,5 +44,5 @@ def __dir__(): sample, util, ) -from evermore.modifier import Modifier # noqa: E402 +from evermore.modifier import Modifier, ModifierBase # noqa: E402 from evermore.parameter import Parameter # noqa: E402 diff --git a/src/evermore/custom_types.py b/src/evermore/custom_types.py index 4c2c27c..49b44db 100644 --- a/src/evermore/custom_types.py +++ b/src/evermore/custom_types.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, NamedTuple, Protocol from jaxtyping import Array @@ -9,8 +9,19 @@ from evermore.modifier import compose +__all__ = [ + "SF", + "AddOrMul", + "ModifierLike", +] + + AddOrMul = Callable[[Array, Array], Array] -AddOrMulSFs = dict[AddOrMul, Array] + + +class SF(NamedTuple): + multiplicative: Array + additive: Array class Sentinel: @@ -28,21 +39,8 @@ def __repr__(self) -> str: _NoValue: Any = Sentinel("") -@runtime_checkable class ModifierLike(Protocol): - def scale_factor(self, sumw: Array) -> AddOrMulSFs: - """ - Always return a dictionary of scale factors for the sumw array. - Dictionary has to look as follows: - - .. code-block:: python - - import operator - from jaxtyping import Array - - - {operator.mul: Array, operator.add: Array} - """ + def scale_factor(self, sumw: Array) -> SF: ... def __call__(self, sumw: Array) -> Array: diff --git a/src/evermore/effect.py b/src/evermore/effect.py index 8d348b4..36f8834 100644 --- a/src/evermore/effect.py +++ b/src/evermore/effect.py @@ -1,15 +1,14 @@ import abc -import operator from typing import TYPE_CHECKING import equinox as eqx import jax.numpy as jnp from jaxtyping import Array, Float -from evermore.custom_types import AddOrMulSFs +from evermore.custom_types import SF from evermore.parameter import Parameter from evermore.pdf import PDF, Flat, Gauss, Poisson -from evermore.util import as1darray, initSF +from evermore.util import as1darray if TYPE_CHECKING: pass @@ -37,7 +36,7 @@ def constraint(self, parameter: Parameter) -> PDF: ... @abc.abstractmethod - def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: + def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: ... @@ -45,10 +44,9 @@ class unconstrained(Effect): def constraint(self, parameter: Parameter) -> PDF: return Flat() - def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: - sf = initSF(shape=parameter.value.shape) - sf[operator.mul] = parameter.value - return sf + def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: + sf = jnp.broadcast_to(parameter.value, sumw.shape) + return SF(multiplicative=sf, additive=jnp.zeros_like(sumw)) DEFAULT_EFFECT = unconstrained() @@ -65,7 +63,7 @@ def constraint(self, parameter: Parameter) -> PDF: mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) ) - def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: + def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: """ Implementation with (inverse) CDFs is defined as follows: @@ -83,9 +81,8 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: return (parameter.value * self.width) + 1 """ - sf = initSF(shape=parameter.value.shape) - sf[operator.mul] = (parameter.value * self.width) + 1 - return sf + sf = jnp.broadcast_to((parameter.value * self.width) + 1, sumw.shape) + return SF(multiplicative=sf, additive=jnp.zeros_like(sumw)) class shape(Effect): @@ -124,11 +121,9 @@ def constraint(self, parameter: Parameter) -> PDF: mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) ) - def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: - p = parameter.value - sf = initSF(shape=p.shape) - sf[operator.add] = self.vshift(sf=p, sumw=sumw) - return sf + def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: + sf = self.vshift(sf=parameter.value, sumw=sumw) + return SF(multiplicative=jnp.ones_like(sumw), additive=sf) # shift = self.vshift(sf=sf, sumw=sumw) # # handle zeros, see: https://github.com/google/jax/issues/5039 # x = jnp.where(sumw == 0.0, 1.0, sumw) @@ -166,7 +161,7 @@ def constraint(self, parameter: Parameter) -> PDF: mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) ) - def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: + def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: """ Implementation with (inverse) CDFs is defined as follows: @@ -184,11 +179,9 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: return jnp.exp(parameter.value * self.interpolate(parameter=parameter)) """ - sf = initSF(shape=parameter.value.shape) - sf[operator.mul] = jnp.exp( - parameter.value * self.interpolate(parameter=parameter) - ) - return sf + interp = self.interpolate(parameter=parameter) + sf = jnp.broadcast_to(jnp.exp(parameter.value * interp), sumw.shape) + return SF(multiplicative=sf, additive=jnp.zeros_like(sumw)) class poisson(Effect): @@ -201,7 +194,6 @@ def constraint(self, parameter: Parameter) -> PDF: assert parameter.value.shape == self.lamb.shape return Poisson(lamb=self.lamb) - def scale_factor(self, parameter: Parameter, sumw: Array) -> AddOrMulSFs: - sf = initSF(shape=parameter.value.shape) - sf[operator.add] = parameter.value + 1 - return sf + def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: + sf = jnp.broadcast_to(parameter.value + 1, sumw.shape) + return SF(multiplicative=sf, additive=jnp.zeros_like(sumw)) diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index 0e01476..f95b11e 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import operator from functools import reduce from typing import TYPE_CHECKING @@ -7,17 +8,18 @@ import equinox as eqx import jax import jax.numpy as jnp +import jax.tree_util as jtu from jaxtyping import Array -from evermore.custom_types import AddOrMul, AddOrMulSFs, ModifierLike +from evermore.custom_types import SF, AddOrMul, ModifierLike from evermore.effect import DEFAULT_EFFECT from evermore.parameter import Parameter -from evermore.util import initSF if TYPE_CHECKING: from evermore.effect import Effect __all__ = [ + "ModifierBase", "Modifier", "compose", "where", @@ -28,15 +30,67 @@ def __dir__(): return __all__ +class AbstractModifier(eqx.Module): + @abc.abstractmethod + def scale_factor(self: ModifierLike, sumw: Array) -> SF: + ... + + @abc.abstractmethod + def __call__(self: ModifierLike, sumw: Array) -> Array: + ... + + @abc.abstractmethod + def __matmul__(self: ModifierLike, other: ModifierLike) -> compose: + ... + + class ApplyFn(eqx.Module): @jax.named_scope("evm.modifier.ApplyFn") def __call__(self: ModifierLike, sumw: Array) -> Array: sf = self.scale_factor(sumw=sumw) # apply - return sf[operator.mul] * (sumw + sf[operator.add]) + return sf.multiplicative * (sumw + sf.additive) + + +class MatMulCompose(eqx.Module): + def __matmul__(self: ModifierLike, other: ModifierLike) -> compose: + return compose(self, other) + + +class ModifierBase(ApplyFn, MatMulCompose, AbstractModifier): + """ + This serves as a base class for all modifiers. + It automatically implements the __call__ method to apply the scale factors to the sumw array + and the __matmul__ method to compose two modifiers. + + Custom modifiers should inherit from this class and implement the scale_factor method. + + Example: + + .. code-block:: python + + import jax.numpy as jnp + import jax.tree_util as jtu + import evermore as evm + class clip(evm.ModifierBase): + modifier: evm.ModifierBase + min_sf: float + max_sf: float -class Modifier(ApplyFn): + def scale_factor(self, sumw: jnp.ndarray) -> evm.SF: + sf = self.modifier.scale_factor(sumw) + return jtu.tree_map(lambda x: jnp.clip(x, self.min_sf, self.max_sf), sf) + + + parameter = evm.Parameter(value=1.1) + modifier = parameter.unconstrained() + + clipped_modifier = clip(modifier=modifier, min_sf=0.8, max_sf=1.2) + """ + + +class Modifier(ModifierBase): """ Create a new modifier for a given parameter and penalty. @@ -89,14 +143,11 @@ def __init__(self, parameter: Parameter, effect: Effect = DEFAULT_EFFECT) -> Non constraint = self.effect.constraint(parameter=self.parameter) self.parameter._set_constraint(constraint, overwrite=False) - def scale_factor(self, sumw: Array) -> AddOrMulSFs: + def scale_factor(self, sumw: Array) -> SF: return self.effect.scale_factor(parameter=self.parameter, sumw=sumw) - def __matmul__(self, other: ModifierLike) -> compose: - return compose(self, other) - -class where(ApplyFn): +class where(ModifierBase): """ Combine two modifiers based on a condition. @@ -125,21 +176,17 @@ class where(ApplyFn): modifier_true: Modifier modifier_false: Modifier - def scale_factor(self, sumw: Array) -> AddOrMulSFs: - sf = initSF(shape=sumw.shape) - + def scale_factor(self, sumw: Array) -> SF: true_sf = self.modifier_true.scale_factor(sumw) false_sf = self.modifier_false.scale_factor(sumw) - for op in operator.mul, operator.add: - sf.update(jnp.where(self.condition, true_sf[op], false_sf[op])) - return sf + def _where(true: Array, false: Array) -> Array: + return jnp.where(self.condition, true, false) - def __matmul__(self, other: ModifierLike) -> compose: - return compose(self, other) + return jtu.tree_map(_where, true_sf, false_sf) -class compose(ApplyFn): +class compose(ModifierBase): """ Composition of multiple modifiers, i.e.: `(f ∘ g ∘ h)(hist) = f(hist) * g(hist) * h(hist)` It behaves like a single modifier, but it is composed of multiple modifiers; it can be arbitrarly nested. @@ -196,7 +243,7 @@ def __init__(self, *modifiers: ModifierLike) -> None: if isinstance(mod, compose): _modifiers.extend(mod.modifiers) else: - assert isinstance(mod, ModifierLike) + assert isinstance(mod, ModifierBase) _modifiers.append(mod) # by now all modifiers are either modifier or staterror self.modifiers = _modifiers @@ -204,23 +251,16 @@ def __init__(self, *modifiers: ModifierLike) -> None: def __len__(self) -> int: return len(self.modifiers) - def scale_factor(self, sumw: Array) -> AddOrMulSFs: + def scale_factor(self, sumw: Array) -> SF: # collect all multiplicative and additive shifts sfs: dict[AddOrMul, list] = {operator.add: [], operator.mul: []} for m in range(len(self)): mod = self.modifiers[m] _sf = mod.scale_factor(sumw) - for op in operator.add, operator.mul: - sfs[op].append(_sf[op]) + sfs[operator.mul].append(_sf.multiplicative) + sfs[operator.add].append(_sf.additive) - sf = initSF(shape=sumw.shape) # calculate the product with for operator.mul and operator.add - for op, init_val in ( - (operator.mul, jnp.ones_like(sumw)), - (operator.add, jnp.zeros_like(sumw)), - ): - sf[op] = reduce(op, sfs[op], init_val) - return sf - - def __matmul__(self, other: ModifierLike) -> compose: - return compose(self, other) + multiplicative_sf = reduce(operator.mul, sfs[operator.mul], jnp.ones_like(sumw)) + additive_sf = reduce(operator.add, sfs[operator.add], jnp.zeros_like(sumw)) + return SF(multiplicative=multiplicative_sf, additive=additive_sf) diff --git a/src/evermore/util.py b/src/evermore/util.py index b151d70..7736af4 100644 --- a/src/evermore/util.py +++ b/src/evermore/util.py @@ -11,10 +11,7 @@ import jax.tree_util as jtu from jaxtyping import Array, ArrayLike, PyTree -from evermore.custom_types import AddOrMulSFs - __all__ = [ - "initSF", "is_parameter", "sum_leaves", "as1darray", @@ -27,10 +24,6 @@ def __dir__(): return __all__ -def initSF(shape: tuple) -> AddOrMulSFs: - return {operator.add: jnp.zeros(shape), operator.mul: jnp.ones(shape)} - - def is_parameter(leaf: Any) -> bool: from evermore import Parameter diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 277ecc0..5d6ca59 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,12 +1,10 @@ from __future__ import annotations -import operator - import jax.numpy as jnp import pytest import evermore as evm -from evermore.custom_types import _NoValue +from evermore.custom_types import SF, _NoValue from evermore.pdf import Flat, Gauss, Poisson @@ -24,10 +22,9 @@ def test_unconstrained(): u = evm.effect.unconstrained() assert u.constraint(p) == Flat() - assert u.scale_factor(p, jnp.array(1.0)) == { - operator.mul: jnp.array([1.0]), - operator.add: jnp.array([0.0]), - } + assert u.scale_factor(p, jnp.array([1.0])) == SF( + multiplicative=jnp.array([1.0]), additive=jnp.array([0.0]) + ) def test_gauss(): @@ -35,10 +32,9 @@ def test_gauss(): g = evm.effect.gauss(width=jnp.array(1.0)) assert g.constraint(p) == Gauss(mean=jnp.array(0.0), width=jnp.array(1.0)) - assert g.scale_factor(p, jnp.array(1.0)) == { - operator.mul: jnp.array([1.0]), - operator.add: jnp.array([0.0]), - } + assert g.scale_factor(p, jnp.array([1.0])) == SF( + multiplicative=jnp.array([1.0]), additive=jnp.array([0.0]) + ) def test_lnN(): @@ -46,10 +42,9 @@ def test_lnN(): ln = evm.effect.lnN(width=jnp.array([0.9, 1.1])) assert ln.constraint(p) == Gauss(mean=jnp.array(0.0), width=jnp.array(1.0)) - assert ln.scale_factor(p, jnp.array(1.0)) == { - operator.mul: jnp.array([1.0]), - operator.add: jnp.array([0.0]), - } + assert ln.scale_factor(p, jnp.array([1.0])) == SF( + multiplicative=jnp.array([1.0]), additive=jnp.array([0.0]) + ) def test_poisson(): From 07e8805f4a2df0c611f0979d5cf868422b2817bb Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Sat, 9 Mar 2024 13:50:57 +0100 Subject: [PATCH 15/22] add mask modifier --- src/evermore/modifier.py | 42 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index f95b11e..d96eabf 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -23,6 +23,7 @@ "Modifier", "compose", "where", + "mask", ] @@ -186,6 +187,47 @@ def _where(true: Array, false: Array) -> Array: return jtu.tree_map(_where, true_sf, false_sf) +class mask(ModifierBase): + """ + Mask a modifier for specific bins. + + The mask is a boolean array (True, False for each bin). + The modifier is only applied to the bins where the mask is True. + + Example: + + .. code-block:: python + + import jax.numpy as jnp + import evermore as evm + + hist = jnp.array([5, 20, 30]) + syst = evm.Parameter(value=0.0) + + norm = syst.lnN(jnp.array([0.9, 1.1])) + mask = jnp.array([True, False, True]) + + modifier = evm.modifier.mask(mask, norm) + + # apply + modifier(hist) + """ + + where: Array = eqx.field(static=True) + modifier: Modifier + + def scale_factor(self, sumw: Array) -> SF: + sf = self.modifier.scale_factor(sumw) + + def _mask(true: Array, false: Array) -> Array: + return jnp.where(self.where, true, false) + + return SF( + multiplicative=_mask(sf.multiplicative, jnp.ones_like(sumw)), + additive=_mask(sf.additive, jnp.zeros_like(sumw)), + ) + + class compose(ModifierBase): """ Composition of multiple modifiers, i.e.: `(f ∘ g ∘ h)(hist) = f(hist) * g(hist) * h(hist)` From fcdcdc70e6480603e26a5b1209dfb45fce6638c0 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Sat, 9 Mar 2024 13:56:05 +0100 Subject: [PATCH 16/22] improve doc strings --- src/evermore/modifier.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index d96eabf..284b7db 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -162,15 +162,23 @@ class where(ModifierBase): import evermore as evm hist = jnp.array([5, 20, 30]) - syst = evm.Parameter(value=0.0) + syst = evm.Parameter(value=0.1) norm = syst.lnN(jnp.array([0.9, 1.1])) shape = syst.shape(up=jnp.array([7, 22, 31]), down=jnp.array([4, 16, 27])) + # apply norm if hist < 10, else apply shape modifier = evm.modifier.where(hist < 10, norm, shape) # apply modifier(hist) + # -> Array([ 5.049494, 20.281374, 30.181376], dtype=float32) + + # for comparison: + norm(hist) + # -> Array([ 5.049494, 20.197975, 30.296963], dtype=float32) + shape(hist) + # -> Array([ 5.1593127, 20.281374 , 30.181376 ], dtype=float32) """ condition: Array = eqx.field(static=True) @@ -202,7 +210,7 @@ class mask(ModifierBase): import evermore as evm hist = jnp.array([5, 20, 30]) - syst = evm.Parameter(value=0.0) + syst = evm.Parameter(value=0.1) norm = syst.lnN(jnp.array([0.9, 1.1])) mask = jnp.array([True, False, True]) @@ -211,6 +219,7 @@ class mask(ModifierBase): # apply modifier(hist) + # -> Array([ 5.049494, 20. , 30.296963], dtype=float32) """ where: Array = eqx.field(static=True) From 287cbb36d399f21aea0561d5e7ec4d1abde7abbd Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Sat, 9 Mar 2024 14:07:10 +0100 Subject: [PATCH 17/22] improve doc string --- src/evermore/modifier.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index 284b7db..75492ff 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -70,16 +70,19 @@ class ModifierBase(ApplyFn, MatMulCompose, AbstractModifier): .. code-block:: python + import equinox as eqx import jax.numpy as jnp import jax.tree_util as jtu + from jaxtyping import Array + import evermore as evm class clip(evm.ModifierBase): modifier: evm.ModifierBase - min_sf: float - max_sf: float + min_sf: float = eqx.field(static=True) + max_sf: float = eqx.field(static=True) - def scale_factor(self, sumw: jnp.ndarray) -> evm.SF: + def scale_factor(self, sumw: Array) -> evm.SF: sf = self.modifier.scale_factor(sumw) return jtu.tree_map(lambda x: jnp.clip(x, self.min_sf, self.max_sf), sf) From 3efe707e30be78d96879554c15f4c7b95ff394f3 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Sat, 9 Mar 2024 14:27:32 +0100 Subject: [PATCH 18/22] add transform modifier --- src/evermore/modifier.py | 41 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index 75492ff..fbef217 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -2,6 +2,7 @@ import abc import operator +from collections.abc import Callable from functools import reduce from typing import TYPE_CHECKING @@ -24,6 +25,7 @@ "compose", "where", "mask", + "transform", ] @@ -82,7 +84,7 @@ class clip(evm.ModifierBase): min_sf: float = eqx.field(static=True) max_sf: float = eqx.field(static=True) - def scale_factor(self, sumw: Array) -> evm.SF: + def scale_factor(self, sumw: Array) -> evm.custrom_types.SF: sf = self.modifier.scale_factor(sumw) return jtu.tree_map(lambda x: jnp.clip(x, self.min_sf, self.max_sf), sf) @@ -240,6 +242,43 @@ def _mask(true: Array, false: Array) -> Array: ) +class transform(ModifierBase): + """ + Transform the scale factors of a modifier. + + The `transform_fn` is a function that is applied to both, multiplicative and additive scale factors. + + Example: + + .. code-block:: python + + import jax.numpy as jnp + import evermore as evm + + hist = jnp.array([5, 20, 30]) + syst = evm.Parameter(value=0.1) + + norm = syst.lnN(jnp.array([0.9, 1.1])) + + transformed_norm = evm.modifier.transform(jnp.sqrt, norm) + + # apply + transformed_norm(hist) + # -> Array([ 5.024686, 20.098743, 30.148115], dtype=float32) + + # for comparison: + norm(hist) + # -> Array([ 5.049494, 20.197975, 30.296963], dtype=float32) + """ + + transform_fn: Callable = eqx.field(static=True) + modifier: Modifier + + def scale_factor(self, sumw: Array) -> SF: + sf = self.modifier.scale_factor(sumw) + return jtu.tree_map(self.transform_fn, sf) + + class compose(ModifierBase): """ Composition of multiple modifiers, i.e.: `(f ∘ g ∘ h)(hist) = f(hist) * g(hist) * h(hist)` From 97fb82f1a40225efe432e83a383cd692dfce04ea Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Sat, 9 Mar 2024 14:46:52 +0100 Subject: [PATCH 19/22] improve doc string --- src/evermore/modifier.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index fbef217..76fcfc9 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -93,6 +93,11 @@ def scale_factor(self, sumw: Array) -> evm.custrom_types.SF: modifier = parameter.unconstrained() clipped_modifier = clip(modifier=modifier, min_sf=0.8, max_sf=1.2) + + # this example is trivial, because you can also implement it with `evm.modifier.transform`: + from functools import partial + + clipped_modifier = evm.modifier.transform(partial(jnp.clip, a_min=0.8, a_max=1.2), modifier) """ From 31a30e4499d5f0a331bbe1d0801193c3810f8dd5 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Sat, 9 Mar 2024 14:57:29 +0100 Subject: [PATCH 20/22] 'sumw' -> 'hist', 'sumw2' -> 'histw2' --- examples/model.py | 8 +++---- src/evermore/custom_types.py | 4 ++-- src/evermore/effect.py | 42 ++++++++++++++++---------------- src/evermore/modifier.py | 46 ++++++++++++++++++------------------ 4 files changed, 50 insertions(+), 50 deletions(-) diff --git a/examples/model.py b/examples/model.py index e0f4886..75cfce6 100644 --- a/examples/model.py +++ b/examples/model.py @@ -13,7 +13,7 @@ class SPlusBModel(eqx.Module): norm2: evm.Parameter shape1: evm.Parameter - def __init__(self, sumw: dict[str, Array], sumw2: dict[str, Array]) -> None: + def __init__(self, hist: dict[str, Array], histw2: dict[str, Array]) -> None: self.mu = evm.Parameter(value=jnp.array([1.0])) self = evm.parameter.auto_init(self) @@ -64,14 +64,14 @@ def __call__(self, hists: dict) -> dict[str, Array]: }, } -sumw = hists["nominal"] -sumw2 = { +hist = hists["nominal"] +histw2 = { "signal": jnp.array([5]), "bkg1": jnp.array([11]), "bkg2": jnp.array([25]), } -model = SPlusBModel(sumw, sumw2) +model = SPlusBModel(hist, histw2) observation = jnp.array([37]) expectations = model(hists) diff --git a/src/evermore/custom_types.py b/src/evermore/custom_types.py index 49b44db..cda36c0 100644 --- a/src/evermore/custom_types.py +++ b/src/evermore/custom_types.py @@ -40,10 +40,10 @@ def __repr__(self) -> str: class ModifierLike(Protocol): - def scale_factor(self, sumw: Array) -> SF: + def scale_factor(self, hist: Array) -> SF: ... - def __call__(self, sumw: Array) -> Array: + def __call__(self, hist: Array) -> Array: ... def __matmul__(self, other: ModifierLike) -> compose: diff --git a/src/evermore/effect.py b/src/evermore/effect.py index 36f8834..05c7c0f 100644 --- a/src/evermore/effect.py +++ b/src/evermore/effect.py @@ -36,7 +36,7 @@ def constraint(self, parameter: Parameter) -> PDF: ... @abc.abstractmethod - def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: + def scale_factor(self, parameter: Parameter, hist: Array) -> SF: ... @@ -44,9 +44,9 @@ class unconstrained(Effect): def constraint(self, parameter: Parameter) -> PDF: return Flat() - def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: - sf = jnp.broadcast_to(parameter.value, sumw.shape) - return SF(multiplicative=sf, additive=jnp.zeros_like(sumw)) + def scale_factor(self, parameter: Parameter, hist: Array) -> SF: + sf = jnp.broadcast_to(parameter.value, hist.shape) + return SF(multiplicative=sf, additive=jnp.zeros_like(hist)) DEFAULT_EFFECT = unconstrained() @@ -63,7 +63,7 @@ def constraint(self, parameter: Parameter) -> PDF: mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) ) - def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: + def scale_factor(self, parameter: Parameter, hist: Array) -> SF: """ Implementation with (inverse) CDFs is defined as follows: @@ -81,8 +81,8 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: return (parameter.value * self.width) + 1 """ - sf = jnp.broadcast_to((parameter.value * self.width) + 1, sumw.shape) - return SF(multiplicative=sf, additive=jnp.zeros_like(sumw)) + sf = jnp.broadcast_to((parameter.value * self.width) + 1, hist.shape) + return SF(multiplicative=sf, additive=jnp.zeros_like(hist)) class shape(Effect): @@ -97,9 +97,9 @@ def __init__( self.up = up # +1 sigma self.down = down # -1 sigma - def vshift(self, sf: Array, sumw: Array) -> Array: + def vshift(self, sf: Array, hist: Array) -> Array: factor = sf - dx_sum = self.up + self.down - 2 * sumw + dx_sum = self.up + self.down - 2 * hist dx_diff = self.up - self.down # taken from https://github.com/nsmith-/jaxfit/blob/8479cd73e733ba35462287753fab44c0c560037b/src/jaxfit/roofit/combine.py#L173C6-L192 @@ -121,13 +121,13 @@ def constraint(self, parameter: Parameter) -> PDF: mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) ) - def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: - sf = self.vshift(sf=parameter.value, sumw=sumw) - return SF(multiplicative=jnp.ones_like(sumw), additive=sf) - # shift = self.vshift(sf=sf, sumw=sumw) + def scale_factor(self, parameter: Parameter, hist: Array) -> SF: + sf = self.vshift(sf=parameter.value, hist=hist) + return SF(multiplicative=jnp.ones_like(hist), additive=sf) + # shift = self.vshift(sf=sf, hist=hist) # # handle zeros, see: https://github.com/google/jax/issues/5039 - # x = jnp.where(sumw == 0.0, 1.0, sumw) - # return jnp.where(sumw == 0.0, shift, (x + shift) / x) + # x = jnp.where(hist == 0.0, 1.0, hist) + # return jnp.where(hist == 0.0, shift, (x + shift) / x) class lnN(Effect): @@ -161,7 +161,7 @@ def constraint(self, parameter: Parameter) -> PDF: mean=jnp.zeros_like(parameter.value), width=jnp.ones_like(parameter.value) ) - def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: + def scale_factor(self, parameter: Parameter, hist: Array) -> SF: """ Implementation with (inverse) CDFs is defined as follows: @@ -180,8 +180,8 @@ def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: """ interp = self.interpolate(parameter=parameter) - sf = jnp.broadcast_to(jnp.exp(parameter.value * interp), sumw.shape) - return SF(multiplicative=sf, additive=jnp.zeros_like(sumw)) + sf = jnp.broadcast_to(jnp.exp(parameter.value * interp), hist.shape) + return SF(multiplicative=sf, additive=jnp.zeros_like(hist)) class poisson(Effect): @@ -194,6 +194,6 @@ def constraint(self, parameter: Parameter) -> PDF: assert parameter.value.shape == self.lamb.shape return Poisson(lamb=self.lamb) - def scale_factor(self, parameter: Parameter, sumw: Array) -> SF: - sf = jnp.broadcast_to(parameter.value + 1, sumw.shape) - return SF(multiplicative=sf, additive=jnp.zeros_like(sumw)) + def scale_factor(self, parameter: Parameter, hist: Array) -> SF: + sf = jnp.broadcast_to(parameter.value + 1, hist.shape) + return SF(multiplicative=sf, additive=jnp.zeros_like(hist)) diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index 76fcfc9..65ca27d 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -35,11 +35,11 @@ def __dir__(): class AbstractModifier(eqx.Module): @abc.abstractmethod - def scale_factor(self: ModifierLike, sumw: Array) -> SF: + def scale_factor(self: ModifierLike, hist: Array) -> SF: ... @abc.abstractmethod - def __call__(self: ModifierLike, sumw: Array) -> Array: + def __call__(self: ModifierLike, hist: Array) -> Array: ... @abc.abstractmethod @@ -49,10 +49,10 @@ def __matmul__(self: ModifierLike, other: ModifierLike) -> compose: class ApplyFn(eqx.Module): @jax.named_scope("evm.modifier.ApplyFn") - def __call__(self: ModifierLike, sumw: Array) -> Array: - sf = self.scale_factor(sumw=sumw) + def __call__(self: ModifierLike, hist: Array) -> Array: + sf = self.scale_factor(hist=hist) # apply - return sf.multiplicative * (sumw + sf.additive) + return sf.multiplicative * (hist + sf.additive) class MatMulCompose(eqx.Module): @@ -63,7 +63,7 @@ def __matmul__(self: ModifierLike, other: ModifierLike) -> compose: class ModifierBase(ApplyFn, MatMulCompose, AbstractModifier): """ This serves as a base class for all modifiers. - It automatically implements the __call__ method to apply the scale factors to the sumw array + It automatically implements the __call__ method to apply the scale factors to the hist array and the __matmul__ method to compose two modifiers. Custom modifiers should inherit from this class and implement the scale_factor method. @@ -84,8 +84,8 @@ class clip(evm.ModifierBase): min_sf: float = eqx.field(static=True) max_sf: float = eqx.field(static=True) - def scale_factor(self, sumw: Array) -> evm.custrom_types.SF: - sf = self.modifier.scale_factor(sumw) + def scale_factor(self, hist: Array) -> evm.custrom_types.SF: + sf = self.modifier.scale_factor(hist) return jtu.tree_map(lambda x: jnp.clip(x, self.min_sf, self.max_sf), sf) @@ -154,8 +154,8 @@ def __init__(self, parameter: Parameter, effect: Effect = DEFAULT_EFFECT) -> Non constraint = self.effect.constraint(parameter=self.parameter) self.parameter._set_constraint(constraint, overwrite=False) - def scale_factor(self, sumw: Array) -> SF: - return self.effect.scale_factor(parameter=self.parameter, sumw=sumw) + def scale_factor(self, hist: Array) -> SF: + return self.effect.scale_factor(parameter=self.parameter, hist=hist) class where(ModifierBase): @@ -195,9 +195,9 @@ class where(ModifierBase): modifier_true: Modifier modifier_false: Modifier - def scale_factor(self, sumw: Array) -> SF: - true_sf = self.modifier_true.scale_factor(sumw) - false_sf = self.modifier_false.scale_factor(sumw) + def scale_factor(self, hist: Array) -> SF: + true_sf = self.modifier_true.scale_factor(hist) + false_sf = self.modifier_false.scale_factor(hist) def _where(true: Array, false: Array) -> Array: return jnp.where(self.condition, true, false) @@ -235,15 +235,15 @@ class mask(ModifierBase): where: Array = eqx.field(static=True) modifier: Modifier - def scale_factor(self, sumw: Array) -> SF: - sf = self.modifier.scale_factor(sumw) + def scale_factor(self, hist: Array) -> SF: + sf = self.modifier.scale_factor(hist) def _mask(true: Array, false: Array) -> Array: return jnp.where(self.where, true, false) return SF( - multiplicative=_mask(sf.multiplicative, jnp.ones_like(sumw)), - additive=_mask(sf.additive, jnp.zeros_like(sumw)), + multiplicative=_mask(sf.multiplicative, jnp.ones_like(hist)), + additive=_mask(sf.additive, jnp.zeros_like(hist)), ) @@ -279,8 +279,8 @@ class transform(ModifierBase): transform_fn: Callable = eqx.field(static=True) modifier: Modifier - def scale_factor(self, sumw: Array) -> SF: - sf = self.modifier.scale_factor(sumw) + def scale_factor(self, hist: Array) -> SF: + sf = self.modifier.scale_factor(hist) return jtu.tree_map(self.transform_fn, sf) @@ -349,16 +349,16 @@ def __init__(self, *modifiers: ModifierLike) -> None: def __len__(self) -> int: return len(self.modifiers) - def scale_factor(self, sumw: Array) -> SF: + def scale_factor(self, hist: Array) -> SF: # collect all multiplicative and additive shifts sfs: dict[AddOrMul, list] = {operator.add: [], operator.mul: []} for m in range(len(self)): mod = self.modifiers[m] - _sf = mod.scale_factor(sumw) + _sf = mod.scale_factor(hist) sfs[operator.mul].append(_sf.multiplicative) sfs[operator.add].append(_sf.additive) # calculate the product with for operator.mul and operator.add - multiplicative_sf = reduce(operator.mul, sfs[operator.mul], jnp.ones_like(sumw)) - additive_sf = reduce(operator.add, sfs[operator.add], jnp.zeros_like(sumw)) + multiplicative_sf = reduce(operator.mul, sfs[operator.mul], jnp.ones_like(hist)) + additive_sf = reduce(operator.add, sfs[operator.add], jnp.zeros_like(hist)) return SF(multiplicative=multiplicative_sf, additive=additive_sf) From 6fe7adb9598d3cd88e238dbffa79a58f99c34dd2 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Sat, 9 Mar 2024 15:35:48 +0100 Subject: [PATCH 21/22] minor improvements --- src/evermore/parameter.py | 2 +- src/evermore/pdf.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/evermore/parameter.py b/src/evermore/parameter.py index 44dbd19..73471b9 100644 --- a/src/evermore/parameter.py +++ b/src/evermore/parameter.py @@ -58,7 +58,7 @@ def _set_constraint(self, constraint: PDF, overwrite: bool = False) -> PDF: # will always be compatible within the same class (underlying arrays are equal by construction). # This significantly speeds up this check. if self.constraint.__class__ is not constraint.__class__: - msg = f"Parameter constraint '{self.constraint}' is different than the constraint {constraint} to be added." + msg = f"Parameter constraint '{self.constraint}' is different from the new constraint '{constraint}'." raise ValueError(msg) return cast(PDF, self.constraint) diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index 3064cc7..ad61a7a 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -43,13 +43,13 @@ def sample(self, key: PRNGKeyArray) -> Array: class Flat(PDF): def logpdf(self, x: Array) -> Array: - return jnp.array([0.0]) + return jnp.zeros_like(x) def pdf(self, x: Array) -> Array: - return jnp.array([1.0]) + return jnp.ones_like(x) def cdf(self, x: Array) -> Array: - return jnp.array([1.0]) + return jnp.ones_like(x) def sample(self, key: PRNGKeyArray) -> Array: # sample parameter from pdf From 578c2370bf814ec683119c0858b808353ee74f60 Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Sat, 9 Mar 2024 16:14:51 +0100 Subject: [PATCH 22/22] minor improvements --- src/evermore/custom_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/evermore/custom_types.py b/src/evermore/custom_types.py index cda36c0..339eba8 100644 --- a/src/evermore/custom_types.py +++ b/src/evermore/custom_types.py @@ -25,7 +25,7 @@ class SF(NamedTuple): class Sentinel: - repr: str + __slots__ = ("repr",) def __init__(self, repr: str) -> None: self.repr = repr