Skip to content

Commit

Permalink
fix: eval monitor in mo setting
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Apr 25, 2024
1 parent bc6c549 commit 920a836
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 9 deletions.
16 changes: 9 additions & 7 deletions src/evox/monitors/eval_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __init__(
self.topk_solutions = None
self.pf_solutions = None
self.pf_fitness = None
self.latest_solution = None
self.latest_fitness = None
self.eval_count = 0
self.opt_direction = 1 # default to min, so no transformation is needed

Expand Down Expand Up @@ -133,9 +135,6 @@ def record_fit_single_obj(self, cand_sol, cand_fit, fitness):
self.topk_fitness = self.topk_fitness[topk_rank]

def record_fit_multi_obj(self, cand_sol, fitness):
if cand_fit is None:
cand_fit = fitness

if self.full_sol_history:
self.solution_history.append(cand_sol)

Expand All @@ -159,12 +158,15 @@ def record_fit_multi_obj(self, cand_sol, fitness):
pf = rank == 0
self.pf_fitness = self.pf_fitness[pf]
self.pf_solutions = self.pf_solutions[pf]
else:
self.pf_fitness = fitness
self.pf_solutions = cand_sol

self.latest_fitness = fitness
self.latest_solution = cand_sol

def get_latest_fitness(self):
return self.opt_direction * self.fitness_history[-1]
return self.opt_direction * self.latest_fitness

def get_latest_solution(self):
return self.latest_solution

def get_pf_fitness(self):
return self.opt_direction * self.pf_fitness
Expand Down
38 changes: 36 additions & 2 deletions tests/test_monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,35 @@ def test_eval_monitor_with_so(full_fit_history, full_sol_history, topk):
assert (monitor.get_topk_solutions() == pop2[-topk:][::-1]).all()


@pytest.mark.parametrize(
"full_fit_history,full_sol_history",
[
(False, False),
(False, True),
(True, False),
(True, True),
],
)
def test_eval_monitor_with_mo(full_fit_history, full_sol_history):
monitor = EvalMonitor(
full_fit_history=full_fit_history, full_sol_history=full_sol_history
)
monitor.set_opt_direction = 1

pop1 = jnp.arange(15).reshape((3, 5))
fitness1 = jnp.arange(6).reshape(3, 2)

monitor.post_eval(None, pop1, None, fitness1)
assert (monitor.get_latest_fitness() == fitness1).all()
assert (monitor.get_latest_solution() == pop1).all()

pop2 = -jnp.arange(15).reshape((3, 5))
fitness2 = -jnp.arange(6).reshape(3, 2)
monitor.post_eval(None, pop2, None, fitness2)
assert (monitor.get_latest_fitness() == fitness2).all()
assert (monitor.get_latest_solution() == pop2).all()


@pytest.mark.parametrize("fitness_only", [True, False])
def test_pop_monitor(fitness_only):
monitor = PopMonitor(fitness_only=fitness_only)
Expand All @@ -94,6 +123,11 @@ def test_pop_monitor(fitness_only):
key = jax.random.PRNGKey(0)
state = workflow.init(key)
state = workflow.step(state)
assert (monitor.get_latest_fitness() == state.get_child_state("algorithm").fitness).all()
assert (
monitor.get_latest_fitness() == state.get_child_state("algorithm").fitness
).all()
if not fitness_only:
assert (monitor.get_latest_population() == state.get_child_state("algorithm").population).all()
assert (
monitor.get_latest_population()
== state.get_child_state("algorithm").population
).all()

0 comments on commit 920a836

Please sign in to comment.