Skip to content

Commit

Permalink
test: update test to use new api
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Sep 9, 2024
1 parent 7e2476b commit bcd30f6
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 16 deletions.
15 changes: 5 additions & 10 deletions tests/test_containers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import pytest
from evox import algorithms, workflows, problems, Stateful
from evox import algorithms, workflows, problems, Stateful, use_state
from evox.monitors import EvalMonitor


Expand Down Expand Up @@ -29,7 +29,7 @@ def test_clustered_cma_es():
for i in range(200):
state = workflow.step(state)

min_fitness = monitor.get_best_fitness()
min_fitness, _state = use_state(monitor.get_best_fitness)(state)
assert min_fitness < 2


Expand Down Expand Up @@ -60,10 +60,8 @@ def test_vectorized_coevolution(random_subpop):

for i in range(200):
state = workflow.step(state)

monitor.close()

min_fitness = monitor.get_best_fitness()
min_fitness, _state = use_state(monitor.get_best_fitness)(state)
assert min_fitness < 0.5


Expand Down Expand Up @@ -94,10 +92,8 @@ def test_coevolution(random_subpop):
state = workflow.init(key)
for i in range(400):
state = workflow.step(state)

monitor.close()

min_fitness = monitor.get_best_fitness()
min_fitness, _state = use_state(monitor.get_best_fitness)(state)
assert min_fitness < 0.5


Expand Down Expand Up @@ -129,6 +125,5 @@ def test_random_mask_cso():
for i in range(10):
state = workflow.step(state)

min_fitness = monitor.get_best_fitness()
print(min_fitness)
min_fitness, _state = use_state(monitor.get_best_fitness)(state)
assert abs(min_fitness - 19.6) < 0.1
5 changes: 3 additions & 2 deletions tests/test_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def __call__(self, x):
for i in range(2):
state = workflow.step(state)

monitor.flush()
min_fitness, state = use_state(monitor.get_best_fitness)(state)
min_fitness, _state = use_state(monitor.get_best_fitness)(state)
fit_history, _state = use_state(monitor.get_fitness_history)(state)
print(fit_history)
# gym is deterministic, so the result should always be the same
assert min_fitness == 40.0

Expand Down
6 changes: 3 additions & 3 deletions tests/test_neuroevolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
from flax import linen as nn

from evox import algorithms, problems, workflows
from evox import algorithms, problems, workflows, use_state
from evox.monitors import EvalMonitor
from evox.utils import TreeAndVector, rank_based_fitness

Expand Down Expand Up @@ -67,5 +67,5 @@ def loss_func(weight, data):
for i in range(3):
state = workflow.step(state)

best_fitness = monitor.get_best_fitness(state).item()
assert math.isclose(best_fitness, 0.07662, abs_tol=0.01)
best_fitness, _state = use_state(monitor.get_best_fitness)(state)
assert math.isclose(best_fitness.item(), 0.07662, abs_tol=0.01)
1 change: 0 additions & 1 deletion tests/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def run_std_workflow_with_jit_problem():
# init the workflow
key = jax.random.PRNGKey(42)
state = workflow.init(key)
state = workflow.enable_multi_devices(state)

# run the workflow for 100 steps
for i in range(100):
Expand Down

0 comments on commit bcd30f6

Please sign in to comment.