Skip to content

Commit

Permalink
dev: add NSGA-III algorithm
Browse files Browse the repository at this point in the history
* create fork and start nsga3

* calculate ideal/extreme points, plane & intercepts

* finish association & start niching

* finish associate

* the last while ...

* roughly finished

* ready to debug & test

* checkout

* remote

* remote

* ok until Niche

* fin

* NSGA-III Fin!

* ex -> evox

---------

Co-authored-by: Artanisax <[email protected]>
  • Loading branch information
Artanisax and Artanisax authored Jul 28, 2023
1 parent 6359564 commit da11de1
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/evox/algorithms/mo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .rvea import RVEA
from .moead import MOEAD
from .ibea import IBEA
from .nsga3 import NSGA3
172 changes: 172 additions & 0 deletions src/evox/algorithms/mo/nsga3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import evox
import jax
import jax.numpy as jnp

from evox.operators.selection import UniformRandomSelection
from evox.operators.mutation import GaussianMutation, PmMutation
from evox.operators.crossover import UniformCrossover, SimulatedBinaryCrossover
from evox.operators.sampling import UniformSampling
from evox.operators import non_dominated_sort


@evox.jit_class
class NSGA3(evox.Algorithm):
"""NSGA-III algorithm
link: https://ieeexplore.ieee.org/document/6600851
"""

def __init__(
self,
lb,
ub,
n_objs,
pop_size,
ref=None,
selection=UniformRandomSelection(p=0.5),
mutation=GaussianMutation(),
# mutation=PmMutation(),
crossover=UniformCrossover(),
# crossover=SimulatedBinaryCrossover(),
):
self.lb = lb
self.ub = ub
self.n_objs = n_objs
self.dim = lb.shape[0]
self.pop_size = pop_size
self.ref = ref if ref else UniformSampling(pop_size, n_objs).random()[0]

self.selection = selection
self.mutation = mutation
self.crossover = crossover

def setup(self, key):
key, subkey = jax.random.split(key)
population = (
jax.random.uniform(subkey, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
+ self.lb
)
return evox.State(
population=population,
fitness=jnp.zeros((self.pop_size, self.n_objs)),
next_generation=population,
is_init=True
)

def ask(self, state):
return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state)

def tell(self, state, fitness):
return jax.lax.cond(
state.is_init, self._tell_init, self._tell_normal, state, fitness
)

def _ask_init(self, state):
return state.population, state

def _ask_normal(self, state):
mutated, state = self.selection(state, state.population)
mutated, state = self.mutation(state, mutated)

crossovered, state = self.selection(state, state.population)
crossovered, state = self.crossover(state, crossovered)

next_generation = jnp.clip(
jnp.concatenate([mutated, crossovered], axis=0), self.lb, self.ub
)
return next_generation, state.update(next_generation=next_generation)

def _tell_init(self, state, fitness):
state = state.update(fitness=fitness, is_init=False)
return state

def _tell_normal(self, state, fitness):
merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0)
merged_fitness = jnp.concatenate([state.fitness, fitness], axis=0)

rank = non_dominated_sort(merged_fitness)
order = jnp.argsort(rank)
rank = rank[order]
ranked_pop = merged_pop[order]
ranked_fitness = merged_fitness[order]
last_rank = rank[self.pop_size]
ranked_fitness = jnp.where(jnp.repeat((rank <= last_rank)[:, None], self.n_objs, axis=1), ranked_fitness, jnp.nan)

# Normalize
ideal = jnp.nanmin(ranked_fitness, axis=0)
offset_fitness = ranked_fitness - ideal
weight = jnp.eye(self.n_objs, self.n_objs) + 1e-6
weighted = jnp.repeat(offset_fitness, self.n_objs, axis=0).reshape(len(offset_fitness), self.n_objs, self.n_objs) / weight
asf = jnp.nanmax(weighted, axis=2)
ex_idx =jnp.argmin(asf, axis=0)
extreme = offset_fitness[ex_idx]

def extreme_point(val):
extreme = val[0]
plane = jnp.linalg.solve(extreme, jnp.ones(self.n_objs))
intercept = 1/ plane
return intercept

def worst_point(val):
return jnp.nanmax(ranked_fitness, axis=0)

nadir_point = jax.lax.cond(jnp.linalg.matrix_rank(extreme) == self.n_objs,
extreme_point, worst_point,
(extreme, offset_fitness))
normalized_fitness = offset_fitness / nadir_point

# Associate
def perpendicular_distance(x, y):
y_norm = jnp.linalg.norm(y, axis=1)
proj_len = x @ y.T / y_norm
unit_vec = y / y_norm[:, None]
proj_vec = jnp.reshape(proj_len, (proj_len.size, 1)) * jnp.tile(unit_vec, (len(x), 1))
prep_vec = jnp.repeat(x, len(y), axis=0) - proj_vec
dist = jnp.reshape(jnp.linalg.norm(prep_vec, axis=1), (len(x), len(y)))
return dist

dist = perpendicular_distance(ranked_fitness, self.ref)
pi = jnp.nanargmin(dist, axis=1)
d = dist[jnp.arange(len(normalized_fitness)), pi]

# Niche
def niche_loop(val):
def nope(val):
idx, i, rho, j = val
rho = rho.at[j].set(self.pop_size)
return idx, i, rho, j

def have(val):
def zero(val):
idx, i, rho, j = val
idx = idx.at[i].set(jnp.nanargmin(jnp.where(pi == j, d, jnp.nan)))
rho = rho.at[j].add(1)
return idx, i+1, rho, j

def already(val):
idx, i, rho, j = val
key = jax.random.PRNGKey(i * j)
temp = jax.random.randint(key, (1, len(ranked_pop)), 0, self.pop_size)
temp = temp + (pi == j) * self.pop_size
idx = idx.at[i].set(jnp.argmax(temp))
rho = rho.at[j].add(1)
return idx, i+1, rho, j

return jax.lax.cond(rho[val[3]], already, zero, val)

idx, i, rho = val
j = jnp.argmin(rho)
idx, i, rho, j = jax.lax.cond(jnp.sum(pi == j), have, nope, (idx, i, rho, j))
return idx, i, rho

survivor_idx = jnp.arange(self.pop_size)
rho = jnp.bincount(jnp.where(rank < last_rank, pi, len(self.ref)), length=len(self.ref))
pi = jnp.where(rank == last_rank, pi, -1)
d = jnp.where(rank == last_rank, d, jnp.nan)
survivor_idx, _, _ = jax.lax.while_loop(lambda val: val[1] < self.pop_size,
niche_loop,
(survivor_idx, jnp.sum(rho), rho))

state = state.update(population=ranked_pop[survivor_idx], fitness=ranked_fitness[survivor_idx])
return state
9 changes: 9 additions & 0 deletions tests/test_multi_objective_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,12 @@ def test_rvea():
pop_size=100,
)
run_moea(algorithm)

def test_nsga3():
algorithm = algorithms.NSGA3(
lb=jnp.full(shape=(3,), fill_value=0),
ub=jnp.full(shape=(3,), fill_value=1),
n_objs=3,
pop_size=100,
)
run_moea(algorithm)

0 comments on commit da11de1

Please sign in to comment.