-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of github.com:EMI-Group/evox
- Loading branch information
Showing
3 changed files
with
182 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
from .rvea import RVEA | ||
from .moead import MOEAD | ||
from .ibea import IBEA | ||
from .nsga3 import NSGA3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import evox | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
from evox.operators.selection import UniformRandomSelection | ||
from evox.operators.mutation import GaussianMutation, PmMutation | ||
from evox.operators.crossover import UniformCrossover, SimulatedBinaryCrossover | ||
from evox.operators.sampling import UniformSampling | ||
from evox.operators import non_dominated_sort | ||
|
||
|
||
@evox.jit_class | ||
class NSGA3(evox.Algorithm): | ||
"""NSGA-III algorithm | ||
link: https://ieeexplore.ieee.org/document/6600851 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
lb, | ||
ub, | ||
n_objs, | ||
pop_size, | ||
ref=None, | ||
selection=UniformRandomSelection(p=0.5), | ||
mutation=GaussianMutation(), | ||
# mutation=PmMutation(), | ||
crossover=UniformCrossover(), | ||
# crossover=SimulatedBinaryCrossover(), | ||
): | ||
self.lb = lb | ||
self.ub = ub | ||
self.n_objs = n_objs | ||
self.dim = lb.shape[0] | ||
self.pop_size = pop_size | ||
self.ref = ref if ref else UniformSampling(pop_size, n_objs).random()[0] | ||
|
||
self.selection = selection | ||
self.mutation = mutation | ||
self.crossover = crossover | ||
|
||
def setup(self, key): | ||
key, subkey = jax.random.split(key) | ||
population = ( | ||
jax.random.uniform(subkey, shape=(self.pop_size, self.dim)) | ||
* (self.ub - self.lb) | ||
+ self.lb | ||
) | ||
return evox.State( | ||
population=population, | ||
fitness=jnp.zeros((self.pop_size, self.n_objs)), | ||
next_generation=population, | ||
is_init=True | ||
) | ||
|
||
def ask(self, state): | ||
return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state) | ||
|
||
def tell(self, state, fitness): | ||
return jax.lax.cond( | ||
state.is_init, self._tell_init, self._tell_normal, state, fitness | ||
) | ||
|
||
def _ask_init(self, state): | ||
return state.population, state | ||
|
||
def _ask_normal(self, state): | ||
mutated, state = self.selection(state, state.population) | ||
mutated, state = self.mutation(state, mutated) | ||
|
||
crossovered, state = self.selection(state, state.population) | ||
crossovered, state = self.crossover(state, crossovered) | ||
|
||
next_generation = jnp.clip( | ||
jnp.concatenate([mutated, crossovered], axis=0), self.lb, self.ub | ||
) | ||
return next_generation, state.update(next_generation=next_generation) | ||
|
||
def _tell_init(self, state, fitness): | ||
state = state.update(fitness=fitness, is_init=False) | ||
return state | ||
|
||
def _tell_normal(self, state, fitness): | ||
merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0) | ||
merged_fitness = jnp.concatenate([state.fitness, fitness], axis=0) | ||
|
||
rank = non_dominated_sort(merged_fitness) | ||
order = jnp.argsort(rank) | ||
rank = rank[order] | ||
ranked_pop = merged_pop[order] | ||
ranked_fitness = merged_fitness[order] | ||
last_rank = rank[self.pop_size] | ||
ranked_fitness = jnp.where(jnp.repeat((rank <= last_rank)[:, None], self.n_objs, axis=1), ranked_fitness, jnp.nan) | ||
|
||
# Normalize | ||
ideal = jnp.nanmin(ranked_fitness, axis=0) | ||
offset_fitness = ranked_fitness - ideal | ||
weight = jnp.eye(self.n_objs, self.n_objs) + 1e-6 | ||
weighted = jnp.repeat(offset_fitness, self.n_objs, axis=0).reshape(len(offset_fitness), self.n_objs, self.n_objs) / weight | ||
asf = jnp.nanmax(weighted, axis=2) | ||
ex_idx =jnp.argmin(asf, axis=0) | ||
extreme = offset_fitness[ex_idx] | ||
|
||
def extreme_point(val): | ||
extreme = val[0] | ||
plane = jnp.linalg.solve(extreme, jnp.ones(self.n_objs)) | ||
intercept = 1/ plane | ||
return intercept | ||
|
||
def worst_point(val): | ||
return jnp.nanmax(ranked_fitness, axis=0) | ||
|
||
nadir_point = jax.lax.cond(jnp.linalg.matrix_rank(extreme) == self.n_objs, | ||
extreme_point, worst_point, | ||
(extreme, offset_fitness)) | ||
normalized_fitness = offset_fitness / nadir_point | ||
|
||
# Associate | ||
def perpendicular_distance(x, y): | ||
y_norm = jnp.linalg.norm(y, axis=1) | ||
proj_len = x @ y.T / y_norm | ||
unit_vec = y / y_norm[:, None] | ||
proj_vec = jnp.reshape(proj_len, (proj_len.size, 1)) * jnp.tile(unit_vec, (len(x), 1)) | ||
prep_vec = jnp.repeat(x, len(y), axis=0) - proj_vec | ||
dist = jnp.reshape(jnp.linalg.norm(prep_vec, axis=1), (len(x), len(y))) | ||
return dist | ||
|
||
dist = perpendicular_distance(ranked_fitness, self.ref) | ||
pi = jnp.nanargmin(dist, axis=1) | ||
d = dist[jnp.arange(len(normalized_fitness)), pi] | ||
|
||
# Niche | ||
def niche_loop(val): | ||
def nope(val): | ||
idx, i, rho, j = val | ||
rho = rho.at[j].set(self.pop_size) | ||
return idx, i, rho, j | ||
|
||
def have(val): | ||
def zero(val): | ||
idx, i, rho, j = val | ||
idx = idx.at[i].set(jnp.nanargmin(jnp.where(pi == j, d, jnp.nan))) | ||
rho = rho.at[j].add(1) | ||
return idx, i+1, rho, j | ||
|
||
def already(val): | ||
idx, i, rho, j = val | ||
key = jax.random.PRNGKey(i * j) | ||
temp = jax.random.randint(key, (1, len(ranked_pop)), 0, self.pop_size) | ||
temp = temp + (pi == j) * self.pop_size | ||
idx = idx.at[i].set(jnp.argmax(temp)) | ||
rho = rho.at[j].add(1) | ||
return idx, i+1, rho, j | ||
|
||
return jax.lax.cond(rho[val[3]], already, zero, val) | ||
|
||
idx, i, rho = val | ||
j = jnp.argmin(rho) | ||
idx, i, rho, j = jax.lax.cond(jnp.sum(pi == j), have, nope, (idx, i, rho, j)) | ||
return idx, i, rho | ||
|
||
survivor_idx = jnp.arange(self.pop_size) | ||
rho = jnp.bincount(jnp.where(rank < last_rank, pi, len(self.ref)), length=len(self.ref)) | ||
pi = jnp.where(rank == last_rank, pi, -1) | ||
d = jnp.where(rank == last_rank, d, jnp.nan) | ||
survivor_idx, _, _ = jax.lax.while_loop(lambda val: val[1] < self.pop_size, | ||
niche_loop, | ||
(survivor_idx, jnp.sum(rho), rho)) | ||
|
||
state = state.update(population=ranked_pop[survivor_idx], fitness=ranked_fitness[survivor_idx]) | ||
return state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters