From da11de11a4d87a5660809f2e2b0bbb023b45f0f5 Mon Sep 17 00:00:00 2001 From: Artanisax <105636380+Artanisax@users.noreply.github.com> Date: Fri, 28 Jul 2023 10:42:03 +0800 Subject: [PATCH] dev: add NSGA-III algorithm * create fork and start nsga3 * calculate ideal/extreme points, plane & intercepts * finish association & start niching * finish associate * the last while ... * roughly finished * ready to debug & test * checkout * remote * remote * ok until Niche * fin * NSGA-III Fin! * ex -> evox --------- Co-authored-by: Artanisax <12110524@mail.sustech.edu.cn> --- src/evox/algorithms/mo/__init__.py | 1 + src/evox/algorithms/mo/nsga3.py | 172 +++++++++++++++++++++++ tests/test_multi_objective_algorithms.py | 9 ++ 3 files changed, 182 insertions(+) create mode 100644 src/evox/algorithms/mo/nsga3.py diff --git a/src/evox/algorithms/mo/__init__.py b/src/evox/algorithms/mo/__init__.py index 7733f5d9..b86ae10b 100644 --- a/src/evox/algorithms/mo/__init__.py +++ b/src/evox/algorithms/mo/__init__.py @@ -2,3 +2,4 @@ from .rvea import RVEA from .moead import MOEAD from .ibea import IBEA +from .nsga3 import NSGA3 diff --git a/src/evox/algorithms/mo/nsga3.py b/src/evox/algorithms/mo/nsga3.py new file mode 100644 index 00000000..e7e02def --- /dev/null +++ b/src/evox/algorithms/mo/nsga3.py @@ -0,0 +1,172 @@ +import evox +import jax +import jax.numpy as jnp + +from evox.operators.selection import UniformRandomSelection +from evox.operators.mutation import GaussianMutation, PmMutation +from evox.operators.crossover import UniformCrossover, SimulatedBinaryCrossover +from evox.operators.sampling import UniformSampling +from evox.operators import non_dominated_sort + + +@evox.jit_class +class NSGA3(evox.Algorithm): + """NSGA-III algorithm + + link: https://ieeexplore.ieee.org/document/6600851 + """ + + def __init__( + self, + lb, + ub, + n_objs, + pop_size, + ref=None, + selection=UniformRandomSelection(p=0.5), + mutation=GaussianMutation(), + # mutation=PmMutation(), + crossover=UniformCrossover(), + # crossover=SimulatedBinaryCrossover(), + ): + self.lb = lb + self.ub = ub + self.n_objs = n_objs + self.dim = lb.shape[0] + self.pop_size = pop_size + self.ref = ref if ref else UniformSampling(pop_size, n_objs).random()[0] + + self.selection = selection + self.mutation = mutation + self.crossover = crossover + + def setup(self, key): + key, subkey = jax.random.split(key) + population = ( + jax.random.uniform(subkey, shape=(self.pop_size, self.dim)) + * (self.ub - self.lb) + + self.lb + ) + return evox.State( + population=population, + fitness=jnp.zeros((self.pop_size, self.n_objs)), + next_generation=population, + is_init=True + ) + + def ask(self, state): + return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state) + + def tell(self, state, fitness): + return jax.lax.cond( + state.is_init, self._tell_init, self._tell_normal, state, fitness + ) + + def _ask_init(self, state): + return state.population, state + + def _ask_normal(self, state): + mutated, state = self.selection(state, state.population) + mutated, state = self.mutation(state, mutated) + + crossovered, state = self.selection(state, state.population) + crossovered, state = self.crossover(state, crossovered) + + next_generation = jnp.clip( + jnp.concatenate([mutated, crossovered], axis=0), self.lb, self.ub + ) + return next_generation, state.update(next_generation=next_generation) + + def _tell_init(self, state, fitness): + state = state.update(fitness=fitness, is_init=False) + return state + + def _tell_normal(self, state, fitness): + merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0) + merged_fitness = jnp.concatenate([state.fitness, fitness], axis=0) + + rank = non_dominated_sort(merged_fitness) + order = jnp.argsort(rank) + rank = rank[order] + ranked_pop = merged_pop[order] + ranked_fitness = merged_fitness[order] + last_rank = rank[self.pop_size] + ranked_fitness = jnp.where(jnp.repeat((rank <= last_rank)[:, None], self.n_objs, axis=1), ranked_fitness, jnp.nan) + + # Normalize + ideal = jnp.nanmin(ranked_fitness, axis=0) + offset_fitness = ranked_fitness - ideal + weight = jnp.eye(self.n_objs, self.n_objs) + 1e-6 + weighted = jnp.repeat(offset_fitness, self.n_objs, axis=0).reshape(len(offset_fitness), self.n_objs, self.n_objs) / weight + asf = jnp.nanmax(weighted, axis=2) + ex_idx =jnp.argmin(asf, axis=0) + extreme = offset_fitness[ex_idx] + + def extreme_point(val): + extreme = val[0] + plane = jnp.linalg.solve(extreme, jnp.ones(self.n_objs)) + intercept = 1/ plane + return intercept + + def worst_point(val): + return jnp.nanmax(ranked_fitness, axis=0) + + nadir_point = jax.lax.cond(jnp.linalg.matrix_rank(extreme) == self.n_objs, + extreme_point, worst_point, + (extreme, offset_fitness)) + normalized_fitness = offset_fitness / nadir_point + + # Associate + def perpendicular_distance(x, y): + y_norm = jnp.linalg.norm(y, axis=1) + proj_len = x @ y.T / y_norm + unit_vec = y / y_norm[:, None] + proj_vec = jnp.reshape(proj_len, (proj_len.size, 1)) * jnp.tile(unit_vec, (len(x), 1)) + prep_vec = jnp.repeat(x, len(y), axis=0) - proj_vec + dist = jnp.reshape(jnp.linalg.norm(prep_vec, axis=1), (len(x), len(y))) + return dist + + dist = perpendicular_distance(ranked_fitness, self.ref) + pi = jnp.nanargmin(dist, axis=1) + d = dist[jnp.arange(len(normalized_fitness)), pi] + + # Niche + def niche_loop(val): + def nope(val): + idx, i, rho, j = val + rho = rho.at[j].set(self.pop_size) + return idx, i, rho, j + + def have(val): + def zero(val): + idx, i, rho, j = val + idx = idx.at[i].set(jnp.nanargmin(jnp.where(pi == j, d, jnp.nan))) + rho = rho.at[j].add(1) + return idx, i+1, rho, j + + def already(val): + idx, i, rho, j = val + key = jax.random.PRNGKey(i * j) + temp = jax.random.randint(key, (1, len(ranked_pop)), 0, self.pop_size) + temp = temp + (pi == j) * self.pop_size + idx = idx.at[i].set(jnp.argmax(temp)) + rho = rho.at[j].add(1) + return idx, i+1, rho, j + + return jax.lax.cond(rho[val[3]], already, zero, val) + + idx, i, rho = val + j = jnp.argmin(rho) + idx, i, rho, j = jax.lax.cond(jnp.sum(pi == j), have, nope, (idx, i, rho, j)) + return idx, i, rho + + survivor_idx = jnp.arange(self.pop_size) + rho = jnp.bincount(jnp.where(rank < last_rank, pi, len(self.ref)), length=len(self.ref)) + pi = jnp.where(rank == last_rank, pi, -1) + d = jnp.where(rank == last_rank, d, jnp.nan) + survivor_idx, _, _ = jax.lax.while_loop(lambda val: val[1] < self.pop_size, + niche_loop, + (survivor_idx, jnp.sum(rho), rho)) + + state = state.update(population=ranked_pop[survivor_idx], fitness=ranked_fitness[survivor_idx]) + return state diff --git a/tests/test_multi_objective_algorithms.py b/tests/test_multi_objective_algorithms.py index 73020ca4..ee4147f0 100644 --- a/tests/test_multi_objective_algorithms.py +++ b/tests/test_multi_objective_algorithms.py @@ -64,3 +64,12 @@ def test_rvea(): pop_size=100, ) run_moea(algorithm) + +def test_nsga3(): + algorithm = algorithms.NSGA3( + lb=jnp.full(shape=(3,), fill_value=0), + ub=jnp.full(shape=(3,), fill_value=1), + n_objs=3, + pop_size=100, + ) + run_moea(algorithm)