From dd07c90b9c57cd3892823019a7fecd8e6ea531fc Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Tue, 7 Nov 2023 16:25:05 +0100 Subject: [PATCH] fix toy generation --- .pre-commit-config.yaml | 11 ++++++----- src/dilax/__init__.py | 2 +- src/dilax/likelihood.py | 12 ++++-------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e1fd96..051bebf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,11 +3,6 @@ ci: autofix_commit_msg: "style: pre-commit fixes" repos: - - repo: https://github.com/psf/black-pre-commit-mirror - rev: "23.7.0" - hooks: - - id: black-jupyter - - repo: https://github.com/pre-commit/pre-commit-hooks rev: "v4.4.0" hooks: @@ -37,6 +32,12 @@ repos: - id: ruff args: ["--fix", "--show-fixes"] + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.0.291 + hooks: + - id: ruff-format + - repo: https://github.com/pre-commit/mirrors-mypy rev: "v1.5.1" hooks: diff --git a/src/dilax/__init__.py b/src/dilax/__init__.py index 22b3f81..10ec66b 100644 --- a/src/dilax/__init__.py +++ b/src/dilax/__init__.py @@ -10,4 +10,4 @@ __contact__ = "https://github.com/pfackeldey/dilax" __license__ = "BSD-3-Clause" __status__ = "Development" -__version__ = "0.1.2" +__version__ = "0.1.3" diff --git a/src/dilax/likelihood.py b/src/dilax/likelihood.py index 9e4a888..2cee7ac 100644 --- a/src/dilax/likelihood.py +++ b/src/dilax/likelihood.py @@ -33,7 +33,7 @@ def logpdf(self, *args, **kwargs) -> jax.Array: def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array: if values is _NoValue: - values = {} + values = self.model.parameter_values model = self.model.update(values=values) res = model.evaluate() nll = ( @@ -58,8 +58,6 @@ def __init__(self, model: Model, observation: jax.Array) -> None: def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array: if values is _NoValue: - values = {} - if not values: values = self.model.parameter_values if TYPE_CHECKING: values = cast(dict[str, jax.Array], values) @@ -77,9 +75,9 @@ class CovMatrix(Hessian): def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array: if values is _NoValue: - values = {} + values = self.model.parameter_values hessian = super().__call__(values=values) - return jnp.linalg.inv(-hessian) + return jnp.linalg.inv(hessian) class SampleToy(BaseModule): @@ -99,13 +97,11 @@ def __call__( key: jax.Array | Sentinel = _NoValue, ) -> dict[str, jax.Array]: if values is _NoValue: - values = {} + values = self.model.parameter_values if key is _NoValue: key = jax.random.PRNGKey(1234) if TYPE_CHECKING: key = cast(jax.Array, key) - if not values: - values = self.model.parameter_values cov = self.CovMatrix(values=values) _values, tree_def = jax.tree_util.tree_flatten( self.model.update(values=values).parameter_values