Skip to content

Commit

Permalink
dev: check if both algorithm and problem are dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Oct 23, 2024
1 parent 1bef2d8 commit 7ed05a3
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions src/evox/workflows/std_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ class StdWorkflow(Workflow):

# inner
_step: Callable[[State], State] = pytree_field(static=True, init=False)
_parallel_step: Callable[[State], State] = pytree_field(static=True, init=False)
_registered_hooks: dict = pytree_field(static=True, init=False)
_pmap_axis_name: str = pytree_field(static=True, init=False)
_opt_direction_mask: jnp.array = pytree_field(init=False)
Expand Down Expand Up @@ -200,7 +199,7 @@ def _step(self, state):

if self.jit_step:
# the first argument is self, which should be static
if dataclasses.is_dataclass(self.algorithm):
if dataclasses.is_dataclass(self.algorithm) and dataclasses.is_dataclass(self.problem):
_step = jax.jit(_step)
else:
_step = jax.jit(_step, static_argnums=(0,))
Expand Down

0 comments on commit 7ed05a3

Please sign in to comment.