Skip to content

Commit

Permalink
fix toy generation
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Nov 7, 2023
1 parent 07eab11 commit dd07c90
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 14 deletions.
11 changes: 6 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@ ci:
autofix_commit_msg: "style: pre-commit fixes"

repos:
- repo: https://github.com/psf/black-pre-commit-mirror
rev: "23.7.0"
hooks:
- id: black-jupyter

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: "v4.4.0"
hooks:
Expand Down Expand Up @@ -37,6 +32,12 @@ repos:
- id: ruff
args: ["--fix", "--show-fixes"]

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.291
hooks:
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.5.1"
hooks:
Expand Down
2 changes: 1 addition & 1 deletion src/dilax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
__contact__ = "https://github.com/pfackeldey/dilax"
__license__ = "BSD-3-Clause"
__status__ = "Development"
__version__ = "0.1.2"
__version__ = "0.1.3"
12 changes: 4 additions & 8 deletions src/dilax/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def logpdf(self, *args, **kwargs) -> jax.Array:

def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array:
if values is _NoValue:
values = {}
values = self.model.parameter_values
model = self.model.update(values=values)
res = model.evaluate()
nll = (
Expand All @@ -58,8 +58,6 @@ def __init__(self, model: Model, observation: jax.Array) -> None:

def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array:
if values is _NoValue:
values = {}
if not values:
values = self.model.parameter_values
if TYPE_CHECKING:
values = cast(dict[str, jax.Array], values)
Expand All @@ -77,9 +75,9 @@ class CovMatrix(Hessian):

def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array:
if values is _NoValue:
values = {}
values = self.model.parameter_values
hessian = super().__call__(values=values)
return jnp.linalg.inv(-hessian)
return jnp.linalg.inv(hessian)


class SampleToy(BaseModule):
Expand All @@ -99,13 +97,11 @@ def __call__(
key: jax.Array | Sentinel = _NoValue,
) -> dict[str, jax.Array]:
if values is _NoValue:
values = {}
values = self.model.parameter_values
if key is _NoValue:
key = jax.random.PRNGKey(1234)
if TYPE_CHECKING:
key = cast(jax.Array, key)
if not values:
values = self.model.parameter_values
cov = self.CovMatrix(values=values)
_values, tree_def = jax.tree_util.tree_flatten(
self.model.update(values=values).parameter_values
Expand Down

0 comments on commit dd07c90

Please sign in to comment.