diff --git a/docs/source/api/algorithms/mo/bce_ibea.rst b/docs/source/api/algorithms/mo/bce_ibea.rst new file mode 100644 index 00000000..40e8de90 --- /dev/null +++ b/docs/source/api/algorithms/mo/bce_ibea.rst @@ -0,0 +1,6 @@ +======= +BCEIBEA +======= + +.. autoclass:: evox.algorithms.BCEIBEA + :members: \ No newline at end of file diff --git a/docs/source/api/algorithms/mo/bige.rst b/docs/source/api/algorithms/mo/bige.rst new file mode 100644 index 00000000..9502e1b6 --- /dev/null +++ b/docs/source/api/algorithms/mo/bige.rst @@ -0,0 +1,6 @@ +==== +BiGE +==== + +.. autoclass:: evox.algorithms.BiGE + :members: diff --git a/docs/source/api/algorithms/mo/eagmoead.rst b/docs/source/api/algorithms/mo/eagmoead.rst new file mode 100644 index 00000000..71128a1d --- /dev/null +++ b/docs/source/api/algorithms/mo/eagmoead.rst @@ -0,0 +1,6 @@ +========== +EAG-MOEA/D +========== + +.. autoclass:: evox.algorithms.EAGMOEAD + :members: diff --git a/docs/source/api/algorithms/mo/gde3 b/docs/source/api/algorithms/mo/gde3 new file mode 100644 index 00000000..3334ba9f --- /dev/null +++ b/docs/source/api/algorithms/mo/gde3 @@ -0,0 +1,6 @@ +==== +GDE3 +==== + +.. autoclass:: evox.algorithms.GDE3 + :members: diff --git a/docs/source/api/algorithms/mo/hype.rst b/docs/source/api/algorithms/mo/hype.rst new file mode 100644 index 00000000..e1aa95cd --- /dev/null +++ b/docs/source/api/algorithms/mo/hype.rst @@ -0,0 +1,6 @@ +===== +HypE +===== + +.. autoclass:: evox.algorithms.HypE + :members: diff --git a/docs/source/api/algorithms/mo/index.rst b/docs/source/api/algorithms/mo/index.rst index 88a5f25b..159902f6 100644 --- a/docs/source/api/algorithms/mo/index.rst +++ b/docs/source/api/algorithms/mo/index.rst @@ -5,8 +5,19 @@ Multi-objective .. toctree:: :maxdepth: 1 + bce_ibea + bige + eagmoead + gde3 + hype ibea + knea moead + moeaddra + moeadm2m nsga2 nsga3 - rvea \ No newline at end of file + rvea + spea2 + sra + tdea \ No newline at end of file diff --git a/docs/source/api/algorithms/mo/knea.rst b/docs/source/api/algorithms/mo/knea.rst new file mode 100644 index 00000000..870f71f6 --- /dev/null +++ b/docs/source/api/algorithms/mo/knea.rst @@ -0,0 +1,6 @@ +==== +KnEA +==== + +.. autoclass:: evox.algorithms.KnEA + :members: diff --git a/docs/source/api/algorithms/mo/moeaddra.rst b/docs/source/api/algorithms/mo/moeaddra.rst new file mode 100644 index 00000000..b57cf1be --- /dev/null +++ b/docs/source/api/algorithms/mo/moeaddra.rst @@ -0,0 +1,6 @@ +========== +MOEA/D-DRA +========== + +.. autoclass:: evox.algorithms.MOEADDRA + :members: diff --git a/docs/source/api/algorithms/mo/moeadm2m.rst b/docs/source/api/algorithms/mo/moeadm2m.rst new file mode 100644 index 00000000..be79e2fb --- /dev/null +++ b/docs/source/api/algorithms/mo/moeadm2m.rst @@ -0,0 +1,6 @@ +========== +MOEA/D-M2M +========== + +.. autoclass:: evox.algorithms.MOEADM2M + :members: diff --git a/docs/source/api/algorithms/mo/spea2.rst b/docs/source/api/algorithms/mo/spea2.rst new file mode 100644 index 00000000..8d45f88d --- /dev/null +++ b/docs/source/api/algorithms/mo/spea2.rst @@ -0,0 +1,6 @@ +===== +SPEA2 +===== + +.. autoclass:: evox.algorithms.SPEA2 + :members: diff --git a/docs/source/api/algorithms/mo/sra.rst b/docs/source/api/algorithms/mo/sra.rst new file mode 100644 index 00000000..57d10fc6 --- /dev/null +++ b/docs/source/api/algorithms/mo/sra.rst @@ -0,0 +1,6 @@ +=== +SRA +=== + +.. autoclass:: evox.algorithms.SRA + :members: diff --git a/docs/source/api/algorithms/mo/tdea.rst b/docs/source/api/algorithms/mo/tdea.rst new file mode 100644 index 00000000..1319c4ab --- /dev/null +++ b/docs/source/api/algorithms/mo/tdea.rst @@ -0,0 +1,6 @@ +==== +tDEA +==== + +.. autoclass:: evox.algorithms.tDEA + :members: diff --git a/src/evox/algorithms/mo/__init__.py b/src/evox/algorithms/mo/__init__.py index 39a7ddeb..a3747156 100644 --- a/src/evox/algorithms/mo/__init__.py +++ b/src/evox/algorithms/mo/__init__.py @@ -10,4 +10,7 @@ from .moeadm2m import MOEADM2M from .knea import KnEA from .bige import BiGE -from .gde3 import GDE3 \ No newline at end of file +from .gde3 import GDE3 +from .sra import SRA +from .tdea import TDEA +from .bce_ibea import BCEIBEA diff --git a/src/evox/algorithms/mo/bce_ibea.py b/src/evox/algorithms/mo/bce_ibea.py new file mode 100644 index 00000000..1e9d4742 --- /dev/null +++ b/src/evox/algorithms/mo/bce_ibea.py @@ -0,0 +1,333 @@ +import jax +import jax.numpy as jnp + +from evox.operators import ( + non_dominated_sort, + selection, + mutation, + crossover, +) +from evox import Algorithm, jit_class, State +from evox.utils import cal_max, pairwise_euclidean_dist +from functools import partial + + +@jax.jit +def cal_fitness(pop_obj, kappa): + n = jnp.shape(pop_obj)[0] + pop_obj = (pop_obj - jnp.tile(jnp.min(pop_obj), (n, 1))) / ( + jnp.tile(jnp.max(pop_obj) - jnp.min(pop_obj), (n, 1)) + ) + I = cal_max(pop_obj, pop_obj) + + C = jnp.max(jnp.abs(I), axis=0) + + fitness = jnp.sum(-jnp.exp(-I / jnp.tile(C, (n, 1)) / kappa), axis=0) + 1 + + return fitness, I, C + + +@partial(jax.jit, static_argnums=3) +def exploration(pc_obj, npc_obj, n_nd, n): + """ + Pareto criterion evolving + + Args: + pc_obj: Objective values of Pareto criterion solutions. + npc_obj: Objective values of non-Pareto criterion solutions. + n_nd: Number of nondominated solutions. + n: Total number of solutions. + + Returns: + s: Boolean array indicating solutions to be explored. + """ + f_max = jnp.max(pc_obj, axis=0) + f_min = jnp.min(pc_obj, axis=0) + norm_pc_obj = (pc_obj - jnp.tile(f_min, (len(pc_obj), 1))) / jnp.tile( + f_max - f_min, (len(pc_obj), 1) + ) + norm_npc_obj = (npc_obj - jnp.tile(f_min, (len(npc_obj), 1))) / jnp.tile( + f_max - f_min, (len(npc_obj), 1) + ) + + # Determine the size of the niche + distance = pairwise_euclidean_dist(norm_pc_obj, norm_pc_obj) + distance = distance.at[ + jnp.arange(0, len(norm_pc_obj)), jnp.arange(0, len(norm_pc_obj)) + ].set(jnp.inf) + distance = jnp.where(jnp.isnan(distance), jnp.inf, distance) + distance = jnp.sort(distance, axis=1) + # Calculate the characteristic distance r0 for niche detection + r0 = jnp.mean(distance[:, jnp.minimum(2, jnp.shape(distance)[1] - 1)]) + r = n_nd / n * r0 + + # Detect the solutions in PC to be explored + # s: Solutions to be explored + distance = pairwise_euclidean_dist(norm_pc_obj, norm_npc_obj) + s = jnp.sum(distance <= r, axis=1) <= 1 + + return s + + +@partial(jax.jit, static_argnums=2) +def pc_selection(pc, pc_obj, n): + """ + Pareto criterion selection + + Args: + pc : Pareto criterion population. + pc_obj : Objective values of pc. + n : Number of solutions to select. + """ + # m: Number of objectives + # n_nd: Number of non-dominated solutions in PC + m = jnp.shape(pc_obj)[1] + rank = non_dominated_sort(pc_obj) + mask = rank == 0 + n_nd = jnp.sum(mask).astype(int) + mask = mask[:, jnp.newaxis] + next_ind = jnp.zeros(n, dtype=jnp.int32) + i = n_nd + + def true_fun(next_ind, mask, i): + f_max = jnp.max(jnp.where(jnp.tile(mask, (1, m)), pc_obj, -jnp.inf), axis=0) + f_min = jnp.min(jnp.where(jnp.tile(mask, (1, m)), pc_obj, jnp.inf), axis=0) + norm_obj = (pc_obj - jnp.tile(f_min, (len(pc_obj), 1))) / jnp.tile( + f_max - f_min, (len(pc_obj), 1) + ) + norm_obj = jnp.where(jnp.tile(mask, (1, m)), norm_obj, jnp.inf) + distance = pairwise_euclidean_dist(norm_obj, norm_obj) + distance = distance.at[ + jnp.arange(0, len(norm_obj)), jnp.arange(0, len(norm_obj)) + ].set(jnp.inf) + distance = jnp.where(jnp.isnan(distance), jnp.inf, distance) + + # Calculate sorted distance matrix (sd) for each solution + sd = jnp.sort(distance, axis=1) + sd = jnp.where(jnp.tile(mask, (1, len(pc_obj))), sd, 0) + + # Calculate the characteristic distance r for niche detection + r = jnp.sum(sd[:, jnp.minimum(2, jnp.shape(sd)[1] - 1)]) / n_nd + + # Calculate big_r which scales the distance matrix + big_r = jnp.minimum(distance / r, 1) + + def loop(vals): + i, mask, big_r = vals + idx = jnp.argmax(1 - jnp.prod(big_r, axis=0)) + mask = mask.at[idx].set(False) + big_r = big_r.at[idx, :].set(1) + big_r = big_r.at[:, idx].set(1) + return (i - 1, mask, big_r) + + _, mask, big_r = jax.lax.while_loop(lambda x: x[0] > n, loop, (i, mask, big_r)) + pc_indices = jnp.where(mask, size=len(mask), fill_value=-1)[0] + next_ind = pc_indices[:n] + return next_ind, mask, i + + def false_fun(next_ind, mask, i): + pc_indices = jnp.where(mask, size=len(mask), fill_value=-1)[0] + head = pc_indices[0] + pc_indices = jnp.where(pc_indices == -1, head, pc_indices) + next_ind = pc_indices[:n] + return next_ind, mask, i + + next_ind, _, _ = jax.lax.cond(n_nd > n, true_fun, false_fun, next_ind, mask, i) + + return pc[next_ind], pc_obj[next_ind], n_nd + + +@partial(jax.jit, static_argnums=2) +def environmental_selection(pop, obj, n, kappa): + + merged_fitness, I, C = cal_fitness(obj, kappa) + next_ind = jnp.arange(len(pop)) + vals = (next_ind, merged_fitness) + + def body_fun(i, vals): + next_ind, merged_fitness = vals + x = jnp.argmin(merged_fitness) + merged_fitness += jnp.exp(-I[x, :] / C[x] / kappa) + merged_fitness = merged_fitness.at[x].set(jnp.max(merged_fitness)) + next_ind = next_ind.at[x].set(-1) + return (next_ind, merged_fitness) + + next_ind, merged_fitness = jax.lax.fori_loop(0, n, body_fun, vals) + + ind = jnp.where(next_ind != -1, size=len(pop), fill_value=-1)[0] + ind_n = ind[0:n] + + return pop[ind_n], obj[ind_n] + + +@jit_class +class BCEIBEA(Algorithm): + """Bi-criterion evolution based IBEA + + link: https://ieeexplore.ieee.org/abstract/document/7347391 + Inspired by PlatEMO. + + Note: The number of outer iterations needs to be set to Maximum Generation*2+1. + + Args: + kappa (float, optional): The scaling factor for selecting parents in the environmental selection. + It controls the probability of selecting parents based on their fitness values. + Defaults to 0.05. + """ + + def __init__( + self, + lb, + ub, + n_objs, + pop_size, + kappa=0.05, + selection_op=None, + mutation_op=None, + crossover_op=None, + ): + self.lb = lb + self.ub = ub + self.n_objs = n_objs + self.dim = lb.shape[0] + self.pop_size = pop_size + self.kappa = kappa + + self.selection = selection_op + self.mutation = mutation_op + self.crossover = crossover_op + + self.selection = selection.Tournament(n_round=self.pop_size) + if self.mutation is None: + self.mutation = mutation.Polynomial((self.lb, self.ub)) + if self.crossover is None: + self.crossover = crossover.SimulatedBinary() + self.crossover_odd = crossover.SimulatedBinary(type=2) + + def setup(self, key): + key, subkey = jax.random.split(key) + population = ( + jax.random.uniform(subkey, shape=(self.pop_size, self.dim)) + * (self.ub - self.lb) + + self.lb + ) + return State( + population=population, + fitness=jnp.zeros((self.pop_size, self.n_objs)), + npc=population, + npc_obj=jnp.zeros((self.pop_size, self.n_objs)), + new_pc=population, + new_pc_obj=jnp.zeros((self.pop_size, self.n_objs)), + new_npc=population, + new_npc_obj=jnp.zeros((self.pop_size, self.n_objs)), + n_nd=0, + next_generation=population, + is_init=True, + counter=1, + key=key, + ) + + def ask(self, state): + return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state) + + def tell(self, state, fitness): + return jax.lax.cond( + state.is_init, self._tell_init, self._tell_normal, state, fitness + ) + + def _ask_init(self, state): + return state.population, state + + def _ask_normal(self, state): + return jax.lax.cond( + state.counter % 2 == 0, self._ask_even, self._ask_odd, state + ) + + def _ask_odd(self, state): + key, mating_key, x_key, mut_key = jax.random.split(state.key, 4) + s = exploration(state.fitness, state.npc_obj, state.n_nd, self.pop_size) + mating_pool = jax.random.randint( + mating_key, shape=(self.pop_size,), minval=0, maxval=self.pop_size + ) + head = jnp.where(s, size=len(s), fill_value=-1)[0] + mating_pool = jnp.where(s, mating_pool, head[0]) + s_indices = jnp.where(s, jnp.arange(0, len(s)), head[0]) + pop = state.population + indices = jnp.concatenate((s_indices, mating_pool)) + + def true_fun(pop): + mating_pop = pop[indices] + coreeovered = self.crossover_odd(x_key, mating_pop) + offspring = self.mutation(mut_key, coreeovered) + return offspring + + pop = jax.lax.cond(jnp.sum(s) != 0, true_fun, lambda x: x, pop) + + return pop, state.update(new_pc=pop, key=key) + + def _ask_even(self, state): + key, sel_key, x_key, mut_key = jax.random.split(state.key, 4) + fit = -cal_fitness(state.npc_obj, self.kappa)[0] + selected, _ = self.selection(sel_key, state.npc, fit) + crossovered = self.crossover(x_key, selected) + next_generation = self.mutation(mut_key, crossovered) + + return next_generation, state.update(new_npc=next_generation, key=key) + + def _tell_init(self, state, fitness): + pc, pc_obj, n_nd = pc_selection(state.population, fitness, self.pop_size) + state = state.update( + population=pc, + fitness=pc_obj, + npc_obj=fitness, + new_pc_obj=fitness, + new_npc_obj=fitness, + n_nd=n_nd, + is_init=False, + ) + return state + + def _tell_normal(self, state, fitness): + return jax.lax.cond( + state.counter % 2 == 0, self._tell_even, self._tell_odd, state, fitness + ) + + def _tell_odd(self, state, fitness): + new_pc_obj = fitness + merged_pop = jnp.concatenate([state.npc, state.new_pc], axis=0) + merged_fitness = jnp.concatenate([state.npc_obj, fitness], axis=0) + npc, npc_obj = environmental_selection( + merged_pop, merged_fitness, self.pop_size, self.kappa + ) + return state.update( + npc=npc, fitness=npc_obj, counter=state.counter + 1, new_pc_obj=new_pc_obj + ) + + def _tell_even(self, state, fitness): + new_npc_obj = fitness + merged_pop = jnp.concatenate([state.npc, state.new_npc], axis=0) + merged_fitness = jnp.concatenate([state.npc_obj, new_npc_obj], axis=0) + + npc, npc_obj = environmental_selection( + merged_pop, merged_fitness, self.pop_size, self.kappa + ) + + merged_pop = jnp.concatenate( + (state.population, state.new_npc, state.new_pc), axis=0 + ) + merged_fitness = jnp.concatenate( + (state.npc_obj, new_npc_obj, state.new_pc_obj), axis=0 + ) + + pc, pc_obj, n_nd = pc_selection(merged_pop, merged_fitness, self.pop_size) + + state = state.update( + population=pc, + fitness=pc_obj, + n_nd=n_nd, + new_npc_obj=new_npc_obj, + npc=npc, + npc_obj=npc_obj, + counter=state.counter + 1, + ) + return state diff --git a/src/evox/algorithms/mo/bige.py b/src/evox/algorithms/mo/bige.py index db152c79..4007d5e9 100644 --- a/src/evox/algorithms/mo/bige.py +++ b/src/evox/algorithms/mo/bige.py @@ -17,29 +17,32 @@ def estimate(fit, mask): # calculate proximity and crowding degree as bi-goals def calc_sh(fit_a, pr_a, fit_b, pr_b, r): dis = euclidean_dist(fit_a, fit_b) - res = ((dis < r) * 0.5 * - ((1 + (pr_a >= pr_b) + (pr_a > pr_b)) * - (1 - dis / r))) ** 2 + res = ( + (dis < r) * 0.5 * ((1 + (pr_a >= pr_b) + (pr_a > pr_b)) * (1 - dis / r)) + ) ** 2 return res - + n, m = jnp.sum(mask), fit.shape[1] - r = 1 / n**(1 / m) + r = 1 / n ** (1 / m) fit_mask = mask[:, None].repeat(m, axis=1) fit = jnp.where(fit_mask, fit, jnp.nan) f_max = jnp.nanmax(fit, axis=0) f_min = jnp.nanmin(fit, axis=0) normed_fit = (fit - f_min) / (f_max - f_min).clip(1e-6) normed_fit = jnp.where(fit_mask, normed_fit, m) - + # pr: proximity # sh: sharing function # cd: crowding degree pr = jnp.sum(normed_fit, axis=1) - sh = vmap(lambda _f_a, _pr_a, _r: - vmap(lambda _f_b, _pr_b: calc_sh(_f_a, _pr_a, _f_b, _pr_b, _r))(normed_fit, pr), - (0, 0, None))(normed_fit, pr, r) + sh = vmap( + lambda _f_a, _pr_a, _r: vmap( + lambda _f_b, _pr_b: calc_sh(_f_a, _pr_a, _f_b, _pr_b, _r) + )(normed_fit, pr), + (0, 0, None), + )(normed_fit, pr, r) cd = jnp.sqrt(jnp.sum(sh, axis=1) - sh.diagonal()) - + bi_fit = jnp.hstack([pr[:, None], cd[:, None]]) bi_mask = mask[:, None].repeat(2, axis=1) bi_fit = jnp.where(bi_mask, bi_fit, jnp.inf) @@ -106,7 +109,7 @@ def _ask_init(self, state): def _ask_normal(self, state): bi_fit = estimate(state.population, jnp.full((self.pop_size,), True)) bi_rank = non_dominated_sort(bi_fit) - + keys = jax.random.split(state.key, 4) selected, _ = self.selection(keys[1], state.population, bi_rank) crossovered = self.crossover(keys[2], selected) @@ -127,11 +130,11 @@ def _tell_normal(self, state, fitness): ranked_pop = merged_pop[order] ranked_fit = merged_fit[order] last_rank = rank[self.pop_size] - + bi_fit = estimate(ranked_fit, rank == last_rank) bi_rank = non_dominated_sort(bi_fit) - + fin_rank = jnp.where(rank >= last_rank, bi_rank, -1) idx = jnp.argsort(fin_rank)[: self.pop_size] state = state.update(population=ranked_pop[idx], fitness=ranked_fit[idx]) - return state \ No newline at end of file + return state diff --git a/src/evox/algorithms/mo/eagmoead.py b/src/evox/algorithms/mo/eagmoead.py index 720b6b12..49977772 100644 --- a/src/evox/algorithms/mo/eagmoead.py +++ b/src/evox/algorithms/mo/eagmoead.py @@ -3,7 +3,13 @@ from functools import partial from evox import jit_class, Algorithm, State -from evox.operators import selection, mutation, crossover, non_dominated_sort, crowding_distance +from evox.operators import ( + selection, + mutation, + crossover, + non_dominated_sort, + crowding_distance, +) from evox.operators.sampling import UniformSampling, LatinHypercubeSampling from evox.utils import pairwise_euclidean_dist @@ -15,7 +21,7 @@ def environmental_selection(fitness, n): worst_rank = rank[order[n - 1]] mask = rank == worst_rank crowding_dis = crowding_distance(fitness, mask) - combined_indices = jnp.lexsort((-crowding_dis, rank))[: n] + combined_indices = jnp.lexsort((-crowding_dis, rank))[:n] return combined_indices @@ -25,6 +31,7 @@ class EAGMOEAD(Algorithm): """EAG-MOEA/D algorithm link: https://ieeexplore.ieee.org/abstract/document/6882229 + Inspired by PlatEMO. """ def __init__( @@ -81,7 +88,7 @@ def setup(self, key): B=B, s=jnp.zeros((self.pop_size, self.LGs)), parent=jnp.zeros((self.pop_size, self.T)).astype(int), - offspring_loc=jnp.zeros((self.pop_size, )).astype(int), + offspring_loc=jnp.zeros((self.pop_size,)).astype(int), gen=0, is_init=True, key=key, @@ -107,11 +114,9 @@ def _ask_normal(self, state): d = s / jnp.sum(s) + 0.002 d = d / jnp.sum(d) - _, offspring_loc = self.selection(sel_key, population, 1./d) + _, offspring_loc = self.selection(sel_key, population, 1.0 / d) parent = jnp.zeros((n, 2)).astype(int) - B = jax.random.permutation( - per_key, B, axis=1, independent=True - ).astype(int) + B = jax.random.permutation(per_key, B, axis=1, independent=True).astype(int) def body_fun(i, val): val = val.at[i, 0].set(B[offspring_loc[i], 0]) @@ -151,7 +156,9 @@ def _tell_normal(self, state, fitness): def body_fun(i, vals): population, pop_obj = vals - g_old = jnp.sum(pop_obj[B[offspring_loc[i], :]] * w[B[offspring_loc[i], :]], axis=1) + g_old = jnp.sum( + pop_obj[B[offspring_loc[i], :]] * w[B[offspring_loc[i], :]], axis=1 + ) g_new = w[B[offspring_loc[i], :]] @ jnp.transpose(offspring_obj[i]) idx = B[offspring_loc[i]] g_new = g_new[:, jnp.newaxis] @@ -177,18 +184,24 @@ def body_fun(i, vals): sucessful = jnp.where(mask, size=self.pop_size) def update_s(s): - h = offspring_loc[combined_order[sucessful]-self.pop_size] + h = offspring_loc[combined_order[sucessful] - self.pop_size] head = h[0] h = jnp.where(h == head, -1, h) h = h.at[0].set(head) hist, _ = jnp.histogram(h, self.pop_size, range=(0, self.pop_size)) - s = s.at[:, gen % self.LGs+1].set(hist) + s = s.at[:, gen % self.LGs + 1].set(hist) return s def no_update(s): return s s = jax.lax.cond(num_valid != 0, update_s, no_update, s) - state = state.update(population=survivor, fitness=survivor_fitness, inner_pop=inner_pop, inner_obj=inner_obj, - s=s, gen=gen) + state = state.update( + population=survivor, + fitness=survivor_fitness, + inner_pop=inner_pop, + inner_obj=inner_obj, + s=s, + gen=gen, + ) return state diff --git a/src/evox/algorithms/mo/gde3.py b/src/evox/algorithms/mo/gde3.py index 1e8c58fa..1091a17f 100644 --- a/src/evox/algorithms/mo/gde3.py +++ b/src/evox/algorithms/mo/gde3.py @@ -26,7 +26,7 @@ def __init__( pop_size, F=0.49, CR=0.97, - ): + ): """ Parameters for Differential Evolution ---------- @@ -95,4 +95,4 @@ def _tell_normal(self, state, fitness): survivor = merged_pop[combined_order] survivor_fitness = merged_fit[combined_order] state = state.update(population=survivor, fitness=survivor_fitness) - return state \ No newline at end of file + return state diff --git a/src/evox/algorithms/mo/hype.py b/src/evox/algorithms/mo/hype.py index d58a24dc..239a761b 100644 --- a/src/evox/algorithms/mo/hype.py +++ b/src/evox/algorithms/mo/hype.py @@ -12,7 +12,7 @@ def calculate_alpha(N, k): for i in range(1, k + 1): num = jnp.prod((k - jnp.arange(1, i)) / (N - jnp.arange(1, i))) - alpha = alpha.at[i-1].set(num / i) + alpha = alpha.at[i - 1].set(num / i) return alpha @@ -26,13 +26,13 @@ def cal_hv(points, ref, k, n_sample, key): s = jax.random.uniform(key, shape=(n_sample, m), minval=f_min, maxval=ref) pds = jnp.zeros((n, n_sample), dtype=bool) - ds = jnp.zeros((n_sample, )) + ds = jnp.zeros((n_sample,)) def body_fun1(i, vals): pds, ds = vals x = jnp.sum((jnp.tile(points[i, :], (n_sample, 1)) - s) <= 0, axis=1) == m pds = pds.at[i].set(jnp.where(x, True, pds[i])) - ds = jnp.where(x, ds+1, ds) + ds = jnp.where(x, ds + 1, ds) return pds, ds pds, ds = jax.lax.fori_loop(0, n, body_fun1, (pds, ds)) @@ -42,7 +42,7 @@ def body_fun1(i, vals): def body_fun2(i, val): temp = jnp.where(pds[i, :], ds, -1).astype(int) - value = jnp.where(temp!=-1, alpha[temp], 0) + value = jnp.where(temp != -1, alpha[temp], 0) value = jnp.sum(value) val = val.at[i].set(value) return val @@ -52,11 +52,13 @@ def body_fun2(i, val): return f + @jit_class class HypE(Algorithm): """HypE algorithm link: https://direct.mit.edu/evco/article-abstract/19/1/45/1363/HypE-An-Algorithm-for-Fast-Hypervolume-Based-Many + Inspired by PlatEMO. """ def __init__( @@ -95,7 +97,7 @@ def setup(self, key): population=population, fitness=jnp.zeros((self.pop_size, self.n_objs)), next_generation=population, - ref_point=jnp.zeros((self.n_objs, )), + ref_point=jnp.zeros((self.n_objs,)), key=key, is_init=True, ) @@ -124,7 +126,7 @@ def _ask_normal(self, state): return next_generation, state.update(next_generation=next_generation) def _tell_init(self, state, fitness): - ref_point = jnp.zeros((self.n_objs, )) + jnp.max(fitness)*1.2 + ref_point = jnp.zeros((self.n_objs,)) + jnp.max(fitness) * 1.2 state = state.update(fitness=fitness, ref_point=ref_point, is_init=False) return state @@ -136,7 +138,7 @@ def _tell_normal(self, state, fitness): rank = non_dominated_sort(merged_obj) order = jnp.argsort(rank) - worst_rank = rank[order[n-1]] + worst_rank = rank[order[n - 1]] mask = rank == worst_rank key, subkey = jax.random.split(state.key) diff --git a/src/evox/algorithms/mo/ibea.py b/src/evox/algorithms/mo/ibea.py index 7d5c4448..d7c040a2 100644 --- a/src/evox/algorithms/mo/ibea.py +++ b/src/evox/algorithms/mo/ibea.py @@ -26,6 +26,7 @@ class IBEA(Algorithm): """IBEA algorithm link: https://link.springer.com/chapter/10.1007/978-3-540-30217-9_84 + Inspired by PlatEMO. """ def __init__( @@ -86,19 +87,16 @@ def _ask_normal(self, state): pop_obj = state.fitness fitness = cal_fitness(pop_obj, self.kappa)[0] - # selected = self.selection(sel_key, population, fitness) selected, _ = self.selection(sel_key, population, -fitness) crossovered = self.crossover(x_key, selected) next_generation = self.mutation(mut_key, crossovered) - # next_generation = jnp.clip(mutated, self.lb, self.ub) return next_generation, state.update(next_generation=next_generation, key=key) def _tell_init(self, state, fitness): state = state.update(fitness=fitness, is_init=False) return state - # @profile def _tell_normal(self, state, fitness): merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0) merged_obj = jnp.concatenate([state.fitness, fitness], axis=0) diff --git a/src/evox/algorithms/mo/knea.py b/src/evox/algorithms/mo/knea.py index cd495cd5..ac2c7423 100644 --- a/src/evox/algorithms/mo/knea.py +++ b/src/evox/algorithms/mo/knea.py @@ -91,7 +91,7 @@ def _ask_init(self, state): def _ask_normal(self, state): rank = non_dominated_sort(state.fitness) DW = calc_DW(state.fitness, self.k_neighbors) - + keys = jax.random.split(state.key, 4) selected, _ = self.selection(keys[1], state.population, -DW, ~state.knee, rank) crossovered = self.crossover(keys[2], selected) @@ -216,4 +216,4 @@ def too_few(info): r=r, t=t, ) - return state \ No newline at end of file + return state diff --git a/src/evox/algorithms/mo/moead.py b/src/evox/algorithms/mo/moead.py index 50d483e2..a3941c79 100644 --- a/src/evox/algorithms/mo/moead.py +++ b/src/evox/algorithms/mo/moead.py @@ -12,6 +12,7 @@ class MOEAD(Algorithm): """MOEA/D algorithm link: https://ieeexplore.ieee.org/document/4358754 + Inspired by PlatEMO. """ def __init__( diff --git a/src/evox/algorithms/mo/moeaddra.py b/src/evox/algorithms/mo/moeaddra.py index 827b1633..4516ecbe 100644 --- a/src/evox/algorithms/mo/moeaddra.py +++ b/src/evox/algorithms/mo/moeaddra.py @@ -7,12 +7,12 @@ from evox.utils import pairwise_euclidean_dist - @jit_class class MOEADDRA(Algorithm): """MOEA/D-DRA algorithm link: https://ieeexplore.ieee.org/abstract/document/4982949 + Inspired by PlatEMO. """ def __init__( @@ -64,10 +64,10 @@ def setup(self, key): weight_vector=w, B=B, Z=jnp.zeros(shape=self.n_objs), - pi=jnp.ones((self.pop_size, )), - old_obj=jnp.zeros((self.pop_size, )), + pi=jnp.ones((self.pop_size,)), + old_obj=jnp.zeros((self.pop_size,)), choosed_p=jnp.zeros((self.pop_size, self.T)).astype(int), - I_all=jnp.zeros((self.pop_size, )).astype(int), + I_all=jnp.zeros((self.pop_size,)).astype(int), gen=0, is_init=True, key=key, @@ -86,12 +86,16 @@ def _ask_init(self, state): def _ask_normal(self, state): - key, subkey1, subkey2, subkey3, sel_key, x_key, mut_key = jax.random.split(state.key, 7) + key, subkey1, subkey2, subkey3, sel_key, x_key, mut_key = jax.random.split( + state.key, 7 + ) parent = jax.random.permutation( subkey1, state.B, axis=1, independent=True ).astype(int) rand = jax.random.uniform(subkey2, (self.pop_size, 1)) - rand_perm = jax.random.randint(subkey3, (self.pop_size, self.T), 0, self.pop_size) + rand_perm = jax.random.randint( + subkey3, (self.pop_size, self.T), 0, self.pop_size + ) w = state.weight_vector pi = state.pi population = state.population @@ -100,14 +104,19 @@ def _ask_normal(self, state): mask = jnp.sum(w < 1e-3, axis=1) == (self.n_objs - 1) boundary = jnp.where(mask, size=self.pop_size, fill_value=0)[0] - boundary = boundary[:self.i_size] - g_bound = jnp.tile(boundary, (5, )) + boundary = boundary[: self.i_size] + g_bound = jnp.tile(boundary, (5,)) I_all = jnp.where(g_bound != 0, g_bound, selected_idx) choosed_p = jnp.where(rand < 0.9, parent[I_all], rand_perm) - crossovered = self.crossover(x_key, population[I_all], population[choosed_p[:, 0]], population[choosed_p[:, 1]]) + crossovered = self.crossover( + x_key, + population[I_all], + population[choosed_p[:, 0]], + population[choosed_p[:, 1]], + ) next_generation = self.mutation(mut_key, crossovered) return next_generation, state.update( @@ -116,7 +125,10 @@ def _ask_normal(self, state): def _tell_init(self, state, fitness): Z = jnp.min(fitness, axis=0) - old_obj = jnp.max(jnp.abs((fitness - jnp.tile(Z, (self.pop_size, 1))) * state.weight_vector), axis=1) + old_obj = jnp.max( + jnp.abs((fitness - jnp.tile(Z, (self.pop_size, 1))) * state.weight_vector), + axis=1, + ) state = state.update(fitness=fitness, Z=Z, old_obj=old_obj, is_init=False) return state @@ -142,13 +154,15 @@ def out_body(i, out_vals): ind_obj = off_obj[i] Z = jnp.minimum(Z, ind_obj) - g_old = jnp.max(jnp.abs(pop_obj[p] - jnp.tile(Z, (len(p), 1))) * w[p], axis=1) - g_new = jnp.max(jnp.abs(jnp.tile(ind_obj-Z, (len(p), 1))) * w[p], axis=1) + g_old = jnp.max( + jnp.abs(pop_obj[p] - jnp.tile(Z, (len(p), 1))) * w[p], axis=1 + ) + g_new = jnp.max(jnp.abs(jnp.tile(ind_obj - Z, (len(p), 1))) * w[p], axis=1) g_new = g_new[:, jnp.newaxis] g_old = g_old[:, jnp.newaxis] - indices = jnp.where(g_old >= g_new, size=len(p))[0][:self.nr] + indices = jnp.where(g_old >= g_new, size=len(p))[0][: self.nr] population = population.at[p[indices]].set(ind_dec) pop_obj = pop_obj.at[p[indices]].set(ind_obj) @@ -158,10 +172,12 @@ def out_body(i, out_vals): population, pop_obj, Z = jax.lax.fori_loop(0, self.pop_size, out_body, out_vals) def update_pi(pi, old_obj): - new_obj = jnp.max(jnp.abs((pop_obj - jnp.tile(Z, (self.pop_size, 1))) * w), axis=1) + new_obj = jnp.max( + jnp.abs((pop_obj - jnp.tile(Z, (self.pop_size, 1))) * w), axis=1 + ) delta = (old_obj - new_obj) / old_obj mask = delta < 0.001 - pi = jnp.where(mask, pi*(0.95 + 0.05*delta/0.001), 1) + pi = jnp.where(mask, pi * (0.95 + 0.05 * delta / 0.001), 1) old_obj = new_obj return pi, old_obj @@ -176,5 +192,12 @@ def no_update(pi, old_obj): old_obj, ) - state = state.update(population=population, fitness=pop_obj, Z=Z, gen=current_gen, pi=pi, old_obj=old_obj) + state = state.update( + population=population, + fitness=pop_obj, + Z=Z, + gen=current_gen, + pi=pi, + old_obj=old_obj, + ) return state diff --git a/src/evox/algorithms/mo/moeadm2m.py b/src/evox/algorithms/mo/moeadm2m.py index 9b4a523b..4762b3fc 100644 --- a/src/evox/algorithms/mo/moeadm2m.py +++ b/src/evox/algorithms/mo/moeadm2m.py @@ -13,25 +13,34 @@ def __call__(self, key, p1, p2, scale): n, d = jnp.shape(p1) subkey1, subkey2 = jax.random.split(key) - rc = (2*jax.random.uniform(subkey1, (n, 1))-1) * (1-jax.random.uniform(subkey2, (n, 1)))**(-(1-scale)**0.7) + rc = (2 * jax.random.uniform(subkey1, (n, 1)) - 1) * ( + 1 - jax.random.uniform(subkey2, (n, 1)) + ) ** (-((1 - scale) ** 0.7)) offspring = p1 + jnp.tile(rc, (1, d)) * (p1 - p2) return offspring @jit_class class Mutation: - def __call__(self, key, p1, off, scale, lb, ub): n, d = jnp.shape(p1) subkey1, subkey2, subkey3, subkey4 = jax.random.split(key, 4) - rm = 0.25 * (2*jax.random.uniform(subkey1, (n, d))-1) * (1-jax.random.uniform(subkey2, (n, d)))**(-(1-scale)**0.7) - site = jax.random.uniform(subkey3, (n, d)) < (1/d) + rm = ( + 0.25 + * (2 * jax.random.uniform(subkey1, (n, d)) - 1) + * (1 - jax.random.uniform(subkey2, (n, d))) ** (-((1 - scale) ** 0.7)) + ) + site = jax.random.uniform(subkey3, (n, d)) < (1 / d) lower = jnp.tile(lb, (n, 1)) upper = jnp.tile(ub, (n, 1)) offspring = jnp.where(site, off + rm * (upper - lower), off) rnd = jax.random.uniform(subkey4, (n, d)) - offspring = jnp.where(offspring < lower, lower + 0.5 * rnd * (p1 - lower), offspring) - offspring = jnp.where(offspring > upper, upper - 0.5 * rnd * (upper - p1), offspring) + offspring = jnp.where( + offspring < lower, lower + 0.5 * rnd * (p1 - lower), offspring + ) + offspring = jnp.where( + offspring > upper, upper - 0.5 * rnd * (upper - p1), offspring + ) return offspring @@ -40,7 +49,6 @@ def associate(rng, pop, obj, w, s): k = len(w) dis = cos_dist(obj, w) max_indices = jnp.argmax(dis, axis=1) - # id_print(max_indices) partition = jnp.zeros((s, k), dtype=int) def body_fun(i, p): @@ -49,7 +57,7 @@ def body_fun(i, p): def true_fun(c): c = c[:s] - rad = jax.random.randint(rng, (s, ), 0, len(pop)) + rad = jax.random.randint(rng, (s,), 0, len(pop)) c = jnp.where(c != -1, c, rad) return c @@ -60,7 +68,7 @@ def false_fun(c): worst_rank = rank[order[s - 1]] mask_worst = rank == worst_rank crowding_dis = crowding_distance(obj, mask_worst) - c = jnp.lexsort((-crowding_dis, rank))[: s] + c = jnp.lexsort((-crowding_dis, rank))[:s] return c current = jax.lax.cond(jnp.sum(mask) < s, true_fun, false_fun, current) @@ -69,7 +77,7 @@ def false_fun(c): partition = jax.lax.fori_loop(0, k, body_fun, partition) - partition = partition.flatten(order='F') + partition = partition.flatten(order="F") return pop[partition], obj[partition] @@ -78,6 +86,7 @@ class MOEADM2M(Algorithm): """MOEA/D based on MOP to MOP algorithm link: https://ieeexplore.ieee.org/abstract/document/6595549 + Inspired by PlatEMO. """ def __init__( @@ -122,7 +131,7 @@ def setup(self, key): w=w, is_init=True, key=key, - gen=0 + gen=0, ) def ask(self, state): @@ -137,21 +146,30 @@ def _ask_init(self, state): return state.population, state def _ask_normal(self, state): - key, local_key, global_key, rnd_key, x_key, mut_key = jax.random.split(state.key, 6) + key, local_key, global_key, rnd_key, x_key, mut_key = jax.random.split( + state.key, 6 + ) current_gen = state.gen scale = current_gen / self.max_gen population = state.population - mating_pool_local = jax.random.randint(local_key, (self.s, self.k), 0, self.s) + \ - jnp.tile(jnp.arange(0, self.s * self.k, self.s), (self.s, 1)) + mating_pool_local = jax.random.randint( + local_key, (self.s, self.k), 0, self.s + ) + jnp.tile(jnp.arange(0, self.s * self.k, self.s), (self.s, 1)) mating_pool_local = mating_pool_local.flatten() - mating_pool_global = jax.random.randint(global_key, (self.pop_size, ), 0, self.pop_size) + mating_pool_global = jax.random.randint( + global_key, (self.pop_size,), 0, self.pop_size + ) rnd = jax.random.uniform(rnd_key, (self.s, self.k)).flatten() mating_pool_local = jnp.where(rnd < 0.7, mating_pool_global, mating_pool_local) - crossovered = self.crossover(x_key, population, population[mating_pool_local], scale) - next_generation = self.mutation(mut_key, population, crossovered, scale, self.lb, self.ub) + crossovered = self.crossover( + x_key, population, population[mating_pool_local], scale + ) + next_generation = self.mutation( + mut_key, population, crossovered, scale, self.lb, self.ub + ) current_gen = current_gen + 1 return next_generation, state.update( @@ -163,7 +181,9 @@ def _tell_init(self, state, fitness): population = state.population population, fitness = associate(subkey, population, fitness, state.w, self.s) - state = state.update(population=population, fitness=fitness, is_init=False, key=key) + state = state.update( + population=population, fitness=fitness, is_init=False, key=key + ) return state def _tell_normal(self, state, fitness): @@ -171,7 +191,9 @@ def _tell_normal(self, state, fitness): merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0) merged_fitness = jnp.concatenate([state.fitness, fitness], axis=0) - population, pop_obj = associate(subkey, merged_pop, merged_fitness, state.w, self.s) + population, pop_obj = associate( + subkey, merged_pop, merged_fitness, state.w, self.s + ) state = state.update(population=population, fitness=pop_obj, key=key) return state diff --git a/src/evox/algorithms/mo/nsga3.py b/src/evox/algorithms/mo/nsga3.py index 6a2e7926..826509d9 100644 --- a/src/evox/algorithms/mo/nsga3.py +++ b/src/evox/algorithms/mo/nsga3.py @@ -48,7 +48,7 @@ def __init__( self.mutation = mutation.Gaussian() if self.crossover is None: self.crossover = crossover.UniformRand() - + self.ref = self.ref / jnp.linalg.norm(self.ref, axis=1)[:, None] def setup(self, key): @@ -63,7 +63,7 @@ def setup(self, key): fitness=jnp.zeros((self.pop_size, self.n_objs)), next_generation=population, is_init=True, - key=key + key=key, ) def ask(self, state): diff --git a/src/evox/algorithms/mo/rvea.py b/src/evox/algorithms/mo/rvea.py index b4e8cc6d..6c771977 100644 --- a/src/evox/algorithms/mo/rvea.py +++ b/src/evox/algorithms/mo/rvea.py @@ -2,76 +2,11 @@ import jax.numpy as jnp from evox.utils import cos_dist -from evox.operators import mutation, crossover +from evox.operators import mutation, crossover, selection from evox.operators.sampling import UniformSampling, LatinHypercubeSampling from evox import Algorithm, State, jit_class -@jax.jit -def ref_vec_guided(x, v, theta): - n, m = jnp.shape(x) - nv = jnp.shape(v)[0] - obj = x - - obj -= jnp.tile(jnp.min(obj, axis=0), (n, 1)) - - cosine = cos_dist(v, v) - cosine = jnp.where(jnp.eye(jnp.shape(cosine)[0], dtype=bool), 0, cosine) - cosine = jnp.clip(cosine, -1, 1) - gamma = jnp.min(jnp.arccos(cosine), axis=1) - - angle = jnp.arccos(cos_dist(obj, v)) - - associate = jnp.argmin(angle, axis=1) - - next_ind = jnp.full(nv, -1) - is_null = jnp.sum(next_ind) - global_min = jnp.inf - global_min_idx = -1 - - vals = next_ind, global_min, global_min_idx - - def update_next(i, sub_index, next_ind, global_min, global_min_idx): - apd = (1 + m * theta * angle[sub_index, i] / gamma[i]) * jnp.sqrt(jnp.sum(obj[sub_index, :] ** 2, axis=1)) - - apd_max = jnp.max(apd) - noise = jnp.where(sub_index == -1, apd_max, 0) - apd = apd + noise - best = jnp.argmin(apd) - - global_min_idx = jnp.where(apd[best] < global_min, sub_index[best.astype(int)], global_min_idx) - global_min = jnp.minimum(apd[best], global_min) - - next_ind = next_ind.at[i].set(sub_index[best.astype(int)]) - return next_ind, global_min, global_min_idx - - def no_update(i, sub_index, next_ind, global_min, global_min_idx): - return next_ind, global_min, global_min_idx - - def body_fun(i, vals): - next_ind, global_min, global_min_idx = vals - sub_index = jnp.where(associate == i, size=nv, fill_value=-1)[0] - - next_ind, global_min, global_min_idx = jax.lax.cond(jnp.sum(sub_index) != is_null, update_next, no_update, i, - sub_index, next_ind, global_min, global_min_idx) - return next_ind, global_min, global_min_idx - - next_ind, global_min, global_min_idx = jax.lax.fori_loop(0, nv, body_fun, vals) - mask = next_ind == -1 - - next_ind = jnp.where(mask, global_min_idx, next_ind) - next_ind = jnp.where(global_min_idx != -1, next_ind, jnp.arange(0, nv)) - - return next_ind - - -@jit_class -class ReferenceVectorGuided: - """Reference vector guided environmental selection.""" - def __call__(self, x, v, theta): - return ref_vec_guided(x, v, theta) - - @jit_class class RVEA(Algorithm): """RVEA algorithms @@ -106,7 +41,7 @@ def __init__( self.crossover = crossover_op if self.selection is None: - self.selection = ReferenceVectorGuided() + self.selection = selection.ReferenceVectorGuided() if self.mutation is None: self.mutation = mutation.Polynomial((lb, ub)) if self.crossover is None: @@ -131,7 +66,7 @@ def setup(self, key): reference_vector=v, is_init=True, key=key, - gen=0 + gen=0, ) def ask(self, state): diff --git a/src/evox/algorithms/mo/spea2.py b/src/evox/algorithms/mo/spea2.py index f45665d2..fbf0dd8d 100644 --- a/src/evox/algorithms/mo/spea2.py +++ b/src/evox/algorithms/mo/spea2.py @@ -13,7 +13,9 @@ def cal_fitness(obj): dom_matrix = _dominate_relation(obj, obj) s = jnp.sum(dom_matrix, axis=1) - r = jax.vmap(lambda s, d: jnp.sum(jnp.where(d, s, 0)), in_axes=(None, 1), out_axes=0)(s, dom_matrix) + r = jax.vmap( + lambda s, d: jnp.sum(jnp.where(d, s, 0)), in_axes=(None, 1), out_axes=0 + )(s, dom_matrix) dis = pairwise_euclidean_dist(obj, obj) diagonal_indices = jnp.arange(n) @@ -21,7 +23,7 @@ def cal_fitness(obj): dis = jnp.sort(dis, axis=1) d = 1 / (dis[:, jnp.floor(jnp.sqrt(6)).astype(int) - 1] + 2) - return d+r + return d + r @jax.jit @@ -59,16 +61,17 @@ class SPEA2(Algorithm): """SPEA2 algorithm link: https://www.research-collection.ethz.ch/handle/20.500.11850/145755 + Inspired by PlatEMO. """ def __init__( - self, - lb, - ub, - n_objs, - pop_size, - mutation_op=None, - crossover_op=None, + self, + lb, + ub, + n_objs, + pop_size, + mutation_op=None, + crossover_op=None, ): self.lb = lb self.ub = ub @@ -88,9 +91,9 @@ def __init__( def setup(self, key): key, subkey = jax.random.split(key) population = ( - jax.random.uniform(subkey, shape=(self.pop_size, self.dim)) - * (self.ub - self.lb) - + self.lb + jax.random.uniform(subkey, shape=(self.pop_size, self.dim)) + * (self.ub - self.lb) + + self.lb ) return State( population=population, @@ -140,13 +143,13 @@ def fitness_sort(mask): return order def truncation(mask): - order = _truncation(merged_fitness, num_valid-self.pop_size, mask) + order = _truncation(merged_fitness, num_valid - self.pop_size, mask) order = jnp.where(order, size=len(mask))[0] return order order = jax.lax.cond(num_valid <= self.pop_size, fitness_sort, truncation, mask) - combined_order = order[:self.pop_size] + combined_order = order[: self.pop_size] survivor = merged_pop[combined_order] survivor_fitness = merged_fitness[combined_order] diff --git a/src/evox/algorithms/mo/sra.py b/src/evox/algorithms/mo/sra.py new file mode 100644 index 00000000..69d12964 --- /dev/null +++ b/src/evox/algorithms/mo/sra.py @@ -0,0 +1,190 @@ +import jax +import jax.numpy as jnp + +from evox.operators import mutation, crossover +from evox.utils import cal_max +from evox import Algorithm, State, jit_class +from functools import partial + + +@jax.jit +def stochastic_ranking_selection(key, pc, I1, I2): + n = len(I1) + n_half = jnp.ceil(n / 2).astype(int) + rank = jnp.arange(0, n) + + def swap_indices(rank, j): + temp = rank[j] + rank = rank.at[j].set(rank[j + 1]) + rank = rank.at[j + 1].set(temp) + return rank + + i = 0 + swapdone = True + rnd_all = jax.random.uniform(key, (n - 1,)) + + def body_fun(vals): + rank, swapdone, i = vals + swapdone = False + + def in_body(j, vals): + rank, swapdone = vals + rnd = rnd_all[j] + + def true_fun(rank, swapdone): + def in_true(rank, swapdone): + rank = swap_indices(rank, j) + swapdone = True + return rank, swapdone + + def in_false(rank, swapdone): + return rank, swapdone + + rank, swapdone = jax.lax.cond( + I1[rank[j]] < I1[rank[j + 1]], in_true, in_false, rank, swapdone + ) + + return rank, swapdone + + def false_fun(rank, swapdone): + def in_true(rank, swapdone): + rank = swap_indices(rank, j) + swapdone = True + return rank, swapdone + + def in_false(rank, swapdone): + return rank, swapdone + + rank, swapdone = jax.lax.cond( + I2[rank[j]] < I2[rank[j + 1]], in_true, in_false, rank, swapdone + ) + return rank, swapdone + + rank, swapdone = jax.lax.cond(rnd < pc, true_fun, false_fun, rank, swapdone) + return (rank, swapdone) + + rank, swapdone = jax.lax.fori_loop(0, n - 1, in_body, (rank, swapdone)) + i += 1 + return rank, swapdone, i + + rank, s, a = jax.lax.while_loop( + lambda vals: (vals[2] < n_half) & (vals[1]), body_fun, (rank, swapdone, i) + ) + return rank + + +@partial(jax.jit, static_argnums=3) +def environmental_selection(key, pop, obj, k, uni_rnd): + n = len(pop) + + I = cal_max(obj, obj) + I1 = jnp.sum(-jnp.exp(-I / 0.05), axis=0) + 1 + + dis = jnp.full((n, n), fill_value=jnp.inf) + + def out_body(i, val): + s_obj = jnp.maximum(obj, jnp.tile(obj[i, :], (n, 1))) + + def in_body(j, d): + d = d.at[i, j].set(jnp.linalg.norm(obj[i, :] - s_obj[j, :])) + return d + + d = jax.lax.fori_loop(0, i, in_body, val) + return d + + dis = jax.lax.fori_loop(0, n, out_body, dis) + I2 = jnp.min(dis, axis=1) + + rank = stochastic_ranking_selection(key, uni_rnd, I1, I2) + indices = rank[:k] + return pop[indices], obj[indices] + + +@jit_class +class SRA(Algorithm): + """Stochastic ranking algorithm + + link: https://ieeexplore.ieee.org/abstract/document/7445185 + Inspired by PlatEMO. + """ + + def __init__( + self, + lb, + ub, + n_objs, + pop_size, + mutation_op=None, + crossover_op=None, + ): + self.lb = lb + self.ub = ub + self.n_objs = n_objs + self.dim = lb.shape[0] + self.pop_size = pop_size + + self.mutation = mutation_op + self.crossover = crossover_op + + if self.mutation is None: + self.mutation = mutation.Polynomial((lb, ub)) + if self.crossover is None: + self.crossover = crossover.SimulatedBinary(type=2) + + def setup(self, key): + key, subkey = jax.random.split(key) + population = ( + jax.random.uniform(subkey, shape=(self.pop_size, self.dim)) + * (self.ub - self.lb) + + self.lb + ) + + return State( + population=population, + fitness=jnp.zeros((self.pop_size, self.n_objs)), + next_generation=population, + is_init=True, + key=key, + ) + + def ask(self, state): + return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state) + + def tell(self, state, fitness): + return jax.lax.cond( + state.is_init, self._tell_init, self._tell_normal, state, fitness + ) + + def _ask_init(self, state): + return state.population, state + + def _ask_normal(self, state): + key, sel_key, x_key, mut_key = jax.random.split(state.key, 4) + + mating_pool = jax.random.randint( + sel_key, (self.pop_size * 2,), 0, self.pop_size + ) + population = state.population[mating_pool] + + crossovered = self.crossover(sel_key, population) + next_generation = self.mutation(mut_key, crossovered) + + return next_generation, state.update(next_generation=next_generation, key=key) + + def _tell_init(self, state, fitness): + state = state.update(fitness=fitness, is_init=False) + return state + + def _tell_normal(self, state, fitness): + merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0) + merged_fitness = jnp.concatenate([state.fitness, fitness], axis=0) + + key, subkey, env_key = jax.random.split(state.key, 3) + pc = jax.random.uniform(subkey) * (0.6 - 0.4) + 0.4 + + population, pop_obj = environmental_selection( + env_key, merged_pop, merged_fitness, self.pop_size, pc + ) + + state = state.update(population=population, fitness=pop_obj) + return state diff --git a/src/evox/algorithms/mo/tdea.py b/src/evox/algorithms/mo/tdea.py new file mode 100644 index 00000000..0932ce34 --- /dev/null +++ b/src/evox/algorithms/mo/tdea.py @@ -0,0 +1,176 @@ +import jax +import jax.numpy as jnp + +from evox.operators import mutation, crossover, sampling +from evox.operators import non_dominated_sort +from evox.utils import cos_dist +from evox import Algorithm, State, jit_class +from functools import partial + + +@jax.jit +def theta_nd_sort(obj, w, mask): + # theta non-dominated sort + n = len(obj) + nw = len(w) + + norm_p = jnp.sqrt(jnp.sum(obj**2, axis=1, keepdims=True)) + cosine = cos_dist(obj, w) + cosine = jnp.clip(cosine, -1, 1) + d1 = jnp.tile(norm_p, (1, nw)) * cosine + d2 = jnp.tile(norm_p, (1, nw)) * jnp.sqrt(1 - cosine**2) + + d_class = jnp.argmin(d2, axis=1) + d_class = jnp.where(mask, d_class, jnp.inf) + + theta = jnp.zeros((nw,)) + 5 + theta = jnp.where(jnp.sum(w > 1e-4, axis=1) == 1, 1e6, theta) + t_rank = jnp.zeros((n,), dtype=int) + + t_list = jnp.arange(1, n + 1) + + def loop_body(i, val): + t_front_no = val + loop_mask = d_class == i + d = jnp.where(loop_mask, d1[:, i] + theta[i] * d2[:, i], jnp.inf) + rank = jnp.argsort(d) + tmp = jnp.where(loop_mask[rank], t_list, t_front_no[rank]) + t_front_no = t_front_no.at[rank].set(tmp) + + return t_front_no + + t_rank = jax.lax.fori_loop(0, nw, loop_body, t_rank) + t_rank = jnp.where(mask, t_rank, jnp.inf) + + return t_rank + + +@partial(jax.jit, static_argnums=3) +def environmental_selection(pop, obj, w, n, z, z_nad): + + n_merge, m = jnp.shape(obj) + rank = non_dominated_sort(obj) + order = jnp.argsort(rank) + worst_rank = rank[order[n]] + mask = rank <= worst_rank + + z = jnp.minimum(z, jnp.min(obj, axis=0)) + + w1 = jnp.zeros((m, m)) + 1e-6 + w1 = jnp.where(jnp.eye(m), 1, w1) + asf = jax.vmap( + lambda x, y: jnp.max(jnp.abs((x - z) / (z_nad - z)) / y, axis=1), + in_axes=(None, 0), + out_axes=1, + )(obj, w1) + + extreme = jnp.argmin(asf, axis=0) + hyper_plane = jnp.linalg.solve( + obj[extreme, :] - jnp.tile(z, (m, 1)), jnp.ones((m, 1)) + ) + a = z + 1 / jnp.squeeze(hyper_plane) + + a = jax.lax.cond( + jnp.any(jnp.isnan(a)) | jnp.any(a <= z), + lambda _: jnp.max(obj, axis=0), + lambda x: x, + a, + ) + z_nad = a + + norm_obj = (obj - jnp.tile(z, (n_merge, 1))) / jnp.tile(z_nad - z, (n_merge, 1)) + + t_rank = theta_nd_sort(norm_obj, w, mask) + combined_order = jnp.lexsort((t_rank, rank))[:n] + + return pop[combined_order], obj[combined_order], z, z_nad + + +@jit_class +class TDEA(Algorithm): + """Theta-dominance based evolutionary algorithm + + link: https://ieeexplore.ieee.org/abstract/document/7080938 + Inspired by PlatEMO. + """ + + def __init__( + self, + lb, + ub, + n_objs, + pop_size, + mutation_op=None, + crossover_op=None, + ): + self.lb = lb + self.ub = ub + self.n_objs = n_objs + self.dim = lb.shape[0] + self.pop_size = pop_size + + self.mutation = mutation_op + self.crossover = crossover_op + + if self.mutation is None: + self.mutation = mutation.Polynomial((lb, ub)) + if self.crossover is None: + self.crossover = crossover.SimulatedBinary() + self.sampling = sampling.LatinHypercubeSampling(self.pop_size, self.n_objs) + + def setup(self, key): + key, subkey1, subkey2 = jax.random.split(key, 3) + population = ( + jax.random.uniform(subkey1, shape=(self.pop_size, self.dim)) + * (self.ub - self.lb) + + self.lb + ) + w = self.sampling(subkey2)[0] + return State( + population=population, + fitness=jnp.zeros((self.pop_size, self.n_objs)), + next_generation=population, + w=w, + z=jnp.zeros((self.n_objs,)), + z_nad=jnp.zeros((self.n_objs,)), + is_init=True, + key=key, + ) + + def ask(self, state): + return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state) + + def tell(self, state, fitness): + return jax.lax.cond( + state.is_init, self._tell_init, self._tell_normal, state, fitness + ) + + def _ask_init(self, state): + return state.population, state + + def _ask_normal(self, state): + key, sel_key, x_key, mut_key = jax.random.split(state.key, 4) + + mating_pool = jax.random.randint(sel_key, (self.pop_size,), 0, self.pop_size) + population = state.population[mating_pool] + crossovered = self.crossover(x_key, population) + next_generation = self.mutation(mut_key, crossovered) + + return next_generation, state.update(next_generation=next_generation, key=key) + + def _tell_init(self, state, fitness): + z = jnp.min(fitness, axis=0) + z_nad = jnp.max(fitness, axis=0) + state = state.update(fitness=fitness, z=z, z_nad=z_nad, is_init=False) + return state + + def _tell_normal(self, state, fitness): + merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0) + merged_fitness = jnp.concatenate([state.fitness, fitness], axis=0) + + population, pop_obj, z, z_nad = environmental_selection( + merged_pop, merged_fitness, state.w, self.pop_size, state.z, state.z_nad + ) + + state = state.update(population=population, fitness=pop_obj, z=z, z_nad=z_nad) + return state diff --git a/src/evox/operators/crossover/sbx.py b/src/evox/operators/crossover/sbx.py index aadc3148..b54449f6 100644 --- a/src/evox/operators/crossover/sbx.py +++ b/src/evox/operators/crossover/sbx.py @@ -9,7 +9,7 @@ def simulated_binary(key, x, pro_c, dis_c, type): mu_key, beta1_key, beta2_key, beta3_key = random.split(key, 4) n, _ = jnp.shape(x) parent1_dec = x[: n // 2, :] - parent2_dec = x[n // 2: n // 2 * 2, :] + parent2_dec = x[n // 2 : n // 2 * 2, :] n_p, d = jnp.shape(parent1_dec) beta = jnp.zeros((n_p, d)) mu = random.uniform(mu_key, shape=(n_p, d)) @@ -43,6 +43,7 @@ def simulated_binary(key, x, pro_c, dis_c, type): @jit_class class SimulatedBinary: """Simulated binary crossover(SBX) + Inspired by PlatEMO. Args: pro_c: the probabilities of doing crossover. diff --git a/src/evox/operators/mutation/pm_mutation.py b/src/evox/operators/mutation/pm_mutation.py index a4fdd7df..cfb37bf8 100644 --- a/src/evox/operators/mutation/pm_mutation.py +++ b/src/evox/operators/mutation/pm_mutation.py @@ -57,6 +57,7 @@ def polynomial(key, x, boundary, pro_m, dis_m): @jit_class class Polynomial: """Polynomial mutation + Inspired by PlatEMO. Args: pro_m: the expectation of number of bits doing mutation. diff --git a/src/evox/operators/sampling/__init__.py b/src/evox/operators/sampling/__init__.py index 761ee832..fbec5f91 100644 --- a/src/evox/operators/sampling/__init__.py +++ b/src/evox/operators/sampling/__init__.py @@ -1,2 +1,2 @@ from .uniform import UniformSampling -from .latin_hypercude import LatinHypercubeSampling \ No newline at end of file +from .latin_hypercude import LatinHypercubeSampling diff --git a/src/evox/operators/sampling/latin_hypercude.py b/src/evox/operators/sampling/latin_hypercude.py index f1f1a610..cf1cceed 100644 --- a/src/evox/operators/sampling/latin_hypercude.py +++ b/src/evox/operators/sampling/latin_hypercude.py @@ -14,9 +14,10 @@ def __init__(self, n=None, m=None): def __call__(self, key): subkeys = jax.random.split(key, self.m) w = jax.random.uniform(key, shape=(self.n, self.m)) - parm = jnp.tile(jnp.arange(1, self.n+1), (self.m, 1)) - parm = jax.vmap(jax.random.permutation, in_axes=(0, 0), out_axes=1)(subkeys, parm) + parm = jnp.tile(jnp.arange(1, self.n + 1), (self.m, 1)) + parm = jax.vmap(jax.random.permutation, in_axes=(0, 0), out_axes=1)( + subkeys, parm + ) w = (parm - w) / self.n n = self.n return w, n - diff --git a/src/evox/operators/sampling/uniform.py b/src/evox/operators/sampling/uniform.py index dd01cf29..e088cddc 100644 --- a/src/evox/operators/sampling/uniform.py +++ b/src/evox/operators/sampling/uniform.py @@ -8,7 +8,10 @@ @evox.jit_class class UniformSampling: - """Uniform sampling use Das and Dennis's method, Deb and Jain's method.""" + """ + Uniform sampling use Das and Dennis's method, Deb and Jain's method. + Inspired by PlatEMO's NBI algorithm. + """ def __init__(self, n=None, m=None): self.n = n @@ -19,21 +22,39 @@ def __call__(self): while comb(h1 + self.m, self.m - 1) <= self.n: h1 += 1 - w = jnp.array(list(n_choose_k(range(1, h1 + self.m), self.m-1))) - \ - jnp.tile(jnp.array(range(self.m-1)), (comb(h1+self.m-1, self.m-1).astype(int), 1)) - 1 - w = (jnp.c_[w, jnp.zeros((jnp.shape(w)[0], 1)) + h1] - - jnp.c_[jnp.zeros((jnp.shape(w)[0], 1)), w]) / h1 + w = ( + jnp.array(list(n_choose_k(range(1, h1 + self.m), self.m - 1))) + - jnp.tile( + jnp.array(range(self.m - 1)), + (comb(h1 + self.m - 1, self.m - 1).astype(int), 1), + ) + - 1 + ) + w = ( + jnp.c_[w, jnp.zeros((jnp.shape(w)[0], 1)) + h1] + - jnp.c_[jnp.zeros((jnp.shape(w)[0], 1)), w] + ) / h1 if h1 < self.m: h2 = 0 - while comb(h1+self.m-1, self.m-1) + comb(h2+self.m, self.m-1) <= self.n: + while ( + comb(h1 + self.m - 1, self.m - 1) + comb(h2 + self.m, self.m - 1) + <= self.n + ): h2 += 1 if h2 > 0: - w2 = jnp.array(list(n_choose_k(range(1, h2+self.m), self.m-1))) - \ - jnp.tile(jnp.array(range(self.m - 1)), (comb(h2+self.m-1, self.m-1).astype(int), 1)) - 1 - w2 = (jnp.c_[w2, jnp.zeros((jnp.shape(w2)[0], 1))+h2] - - jnp.c_[jnp.zeros((jnp.shape(w2)[0], 1)), w2]) / h2 - w = jnp.r_[w, w2/2. + 1./(2.*self.m)] + w2 = ( + jnp.array(list(n_choose_k(range(1, h2 + self.m), self.m - 1))) + - jnp.tile( + jnp.array(range(self.m - 1)), + (comb(h2 + self.m - 1, self.m - 1).astype(int), 1), + ) + - 1 + ) + w2 = ( + jnp.c_[w2, jnp.zeros((jnp.shape(w2)[0], 1)) + h2] + - jnp.c_[jnp.zeros((jnp.shape(w2)[0], 1)), w2] + ) / h2 + w = jnp.r_[w, w2 / 2.0 + 1.0 / (2.0 * self.m)] w = jnp.maximum(w, 1e-6) n = jnp.shape(w)[0] return w, n - diff --git a/src/evox/operators/selection/roulette_wheel.py b/src/evox/operators/selection/roulette_wheel.py index 1fb4f442..fabaedbd 100644 --- a/src/evox/operators/selection/roulette_wheel.py +++ b/src/evox/operators/selection/roulette_wheel.py @@ -16,12 +16,11 @@ def __init__(self, n): def __call__(self, key, x, fitness): fitness = fitness - jnp.minimum(jnp.min(fitness), 0) + 1e-6 - fitness = jnp.cumsum(1. / fitness) + fitness = jnp.cumsum(1.0 / fitness) fitness = fitness / jnp.max(fitness) - random_values = jax.random.uniform(key, shape=(self.n, )) + random_values = jax.random.uniform(key, shape=(self.n,)) selected_indices = jnp.searchsorted(fitness, random_values) return x[selected_indices], selected_indices - diff --git a/src/evox/operators/selection/rvea_selection.py b/src/evox/operators/selection/rvea_selection.py index 51a75d8d..b2a9ff79 100644 --- a/src/evox/operators/selection/rvea_selection.py +++ b/src/evox/operators/selection/rvea_selection.py @@ -14,6 +14,7 @@ def ref_vec_guided(x, v, theta): cosine = cos_dist(v, v) cosine = jnp.where(jnp.eye(jnp.shape(cosine)[0], dtype=bool), 0, cosine) + cosine = jnp.clip(cosine, -1, 1) gamma = jnp.min(jnp.arccos(cosine), axis=1) angle = jnp.arccos(cos_dist(obj, v)) @@ -22,28 +23,53 @@ def ref_vec_guided(x, v, theta): next_ind = jnp.full(nv, -1) is_null = jnp.sum(next_ind) + global_min = jnp.inf + global_min_idx = -1 - def update_next(i, sub_index, next_ind): + vals = next_ind, global_min, global_min_idx + + def update_next(i, sub_index, next_ind, global_min, global_min_idx): apd = (1 + m * theta * angle[sub_index, i] / gamma[i]) * jnp.sqrt( jnp.sum(obj[sub_index, :] ** 2, axis=1) ) + apd_max = jnp.max(apd) noise = jnp.where(sub_index == -1, apd_max, 0) - best = jnp.argmin(apd + noise) + apd = apd + noise + best = jnp.argmin(apd) + + global_min_idx = jnp.where( + apd[best] < global_min, sub_index[best.astype(int)], global_min_idx + ) + global_min = jnp.minimum(apd[best], global_min) + next_ind = next_ind.at[i].set(sub_index[best.astype(int)]) - return next_ind + return next_ind, global_min, global_min_idx - def no_update(_i, _sub_index, next_ind): - return next_ind + def no_update(i, sub_index, next_ind, global_min, global_min_idx): + return next_ind, global_min, global_min_idx - def body_fun(i, val): + def body_fun(i, vals): + next_ind, global_min, global_min_idx = vals sub_index = jnp.where(associate == i, size=nv, fill_value=-1)[0] - next_ind = lax.cond( - jnp.sum(sub_index) != is_null, update_next, no_update, i, sub_index, val + + next_ind, global_min, global_min_idx = lax.cond( + jnp.sum(sub_index) != is_null, + update_next, + no_update, + i, + sub_index, + next_ind, + global_min, + global_min_idx, ) - return next_ind + return next_ind, global_min, global_min_idx - next_ind = lax.fori_loop(0, nv, body_fun, next_ind) + next_ind, global_min, global_min_idx = lax.fori_loop(0, nv, body_fun, vals) + mask = next_ind == -1 + + next_ind = jnp.where(mask, global_min_idx, next_ind) + next_ind = jnp.where(global_min_idx != -1, next_ind, jnp.arange(0, nv)) return next_ind @@ -51,5 +77,6 @@ def body_fun(i, val): @jit_class class ReferenceVectorGuided: """Reference vector guided environmental selection.""" + def __call__(self, x, v, theta): return ref_vec_guided(x, v, theta) diff --git a/src/evox/problems/classic/dtlz.py b/src/evox/problems/classic/dtlz.py index 8d03d067..0e16f59c 100644 --- a/src/evox/problems/classic/dtlz.py +++ b/src/evox/problems/classic/dtlz.py @@ -46,13 +46,23 @@ def evaluate(self, state: chex.PyTreeDef, X: chex.Array): m = self.m n, d = jnp.shape(X) - g = 100 * (d - m + 1 + jnp.sum( - (X[:, m - 1:] - 0.5) ** 2 - - jnp.cos(20 * jnp.pi * (X[:, m - 1:] - 0.5)), - axis=1, keepdims=True)) - f = 0.5 * jnp.tile(1 + g, (1, m)) * jnp.fliplr(jnp.cumprod( - jnp.c_[jnp.ones((n, 1)), X[:, :m - 1]], axis=1)) * \ - jnp.c_[jnp.ones((n, 1)), 1 - X[:, m - 2::-1]] + g = 100 * ( + d + - m + + 1 + + jnp.sum( + (X[:, m - 1 :] - 0.5) ** 2 + - jnp.cos(20 * jnp.pi * (X[:, m - 1 :] - 0.5)), + axis=1, + keepdims=True, + ) + ) + f = ( + 0.5 + * jnp.tile(1 + g, (1, m)) + * jnp.fliplr(jnp.cumprod(jnp.c_[jnp.ones((n, 1)), X[:, : m - 1]], axis=1)) + * jnp.c_[jnp.ones((n, 1)), 1 - X[:, m - 2 :: -1]] + ) return f, state @@ -70,17 +80,28 @@ def __init__(self, d=None, m=None, ref_num=1000): def evaluate(self, state: chex.PyTreeDef, X: chex.Array): m = self.m - g = jnp.sum((X[:, m - 1:] - 0.5) ** 2, axis=1, keepdims=True) - f = jnp.tile(1 + g, (1, m)) * jnp.fliplr(jnp.cumprod(jnp.c_[jnp.ones((jnp.shape(g)[0], 1)), - jnp.cos(X[:, :m - 1] * jnp.pi / 2)], axis=1)) * jnp.c_[jnp.ones((jnp.shape(g)[0], 1)), jnp.sin( - X[:, m - 2::-1] * jnp.pi / 2)] + g = jnp.sum((X[:, m - 1 :] - 0.5) ** 2, axis=1, keepdims=True) + f = ( + jnp.tile(1 + g, (1, m)) + * jnp.fliplr( + jnp.cumprod( + jnp.c_[ + jnp.ones((jnp.shape(g)[0], 1)), + jnp.cos(X[:, : m - 1] * jnp.pi / 2), + ], + axis=1, + ) + ) + * jnp.c_[ + jnp.ones((jnp.shape(g)[0], 1)), jnp.sin(X[:, m - 2 :: -1] * jnp.pi / 2) + ] + ) return f, state def pf(self, state: chex.PyTreeDef): f = self.sample()[0] - f /= jnp.tile(jnp.sqrt(jnp.sum(f ** 2, axis=1, - keepdims=True)), (1, self.m)) + f /= jnp.tile(jnp.sqrt(jnp.sum(f**2, axis=1, keepdims=True)), (1, self.m)) return f, state @@ -91,12 +112,29 @@ def __init__(self, d=None, m=None, ref_num=1000): def evaluate(self, state: chex.PyTreeDef, X: chex.Array): n, d = jnp.shape(X) m = self.m - g = 100 * (d - m + 1 + jnp.sum( - ((X[:, m - 1:] - 0.5) ** 2 - - jnp.cos(20 * jnp.pi * (X[:, m - 1:] - 0.5))), - axis=1, keepdims=True)) - f = jnp.tile(1 + g, (1, m)) * jnp.fliplr(jnp.cumprod(jnp.c_[jnp.ones((n, 1)), - jnp.cos(X[:, :m - 1] * jnp.pi / 2)], axis=1)) * jnp.c_[jnp.ones((n, 1)), jnp.sin(X[:, m - 2::-1] * jnp.pi / 2)] + g = 100 * ( + d + - m + + 1 + + jnp.sum( + ( + (X[:, m - 1 :] - 0.5) ** 2 + - jnp.cos(20 * jnp.pi * (X[:, m - 1 :] - 0.5)) + ), + axis=1, + keepdims=True, + ) + ) + f = ( + jnp.tile(1 + g, (1, m)) + * jnp.fliplr( + jnp.cumprod( + jnp.c_[jnp.ones((n, 1)), jnp.cos(X[:, : m - 1] * jnp.pi / 2)], + axis=1, + ) + ) + * jnp.c_[jnp.ones((n, 1)), jnp.sin(X[:, m - 2 :: -1] * jnp.pi / 2)] + ) return f, state @@ -107,13 +145,23 @@ def __init__(self, d=None, m=None, ref_num=1000): def evaluate(self, state: chex.PyTreeDef, X: chex.Array): m = self.m - X = X.at[:, :m - 1].power(100) - g = jnp.sum((X[:, m - 1:] - 0.5) ** 2, axis=1, keepdims=True) - f = jnp.tile(1 + g, (1, m)) * jnp.fliplr(jnp.cumprod(jnp.c_[jnp.ones((jnp.shape(g)[0], 1)), - jnp.cos(X[:, :m - 1] * jnp.pi / 2)], - axis=1)) * \ - jnp.c_[jnp.ones((jnp.shape(g)[0], 1)), jnp.sin( - X[:, m - 2::-1] * jnp.pi / 2)] + X = X.at[:, : m - 1].power(100) + g = jnp.sum((X[:, m - 1 :] - 0.5) ** 2, axis=1, keepdims=True) + f = ( + jnp.tile(1 + g, (1, m)) + * jnp.fliplr( + jnp.cumprod( + jnp.c_[ + jnp.ones((jnp.shape(g)[0], 1)), + jnp.cos(X[:, : m - 1] * jnp.pi / 2), + ], + axis=1, + ) + ) + * jnp.c_[ + jnp.ones((jnp.shape(g)[0], 1)), jnp.sin(X[:, m - 2 :: -1] * jnp.pi / 2) + ] + ) return f, state @@ -133,29 +181,49 @@ def __init__(self, d=None, m=None, ref_num=1000): def evaluate(self, state: chex.PyTreeDef, X: chex.Array): m = self.m - g = jnp.sum((X[:, m - 1:] - 0.5) ** 2, axis=1, keepdims=True) + g = jnp.sum((X[:, m - 1 :] - 0.5) ** 2, axis=1, keepdims=True) temp = jnp.tile(g, (1, m - 2)) - X = X.at[:, 1:m - 1].set((1 + 2 * temp * - X[:, 1:m - 1]) / (2 + 2 * temp)) - f = jnp.tile(1 + g, (1, m)) * jnp.fliplr(jnp.cumprod(jnp.c_[jnp.ones((jnp.shape(g)[0], 1)), - jnp.cos(X[:, :m - 1] * jnp.pi / 2)], - axis=1)) * \ - jnp.c_[jnp.ones((jnp.shape(g)[0], 1)), jnp.sin( - X[:, m - 2::-1] * jnp.pi / 2)] + X = X.at[:, 1 : m - 1].set((1 + 2 * temp * X[:, 1 : m - 1]) / (2 + 2 * temp)) + f = ( + jnp.tile(1 + g, (1, m)) + * jnp.fliplr( + jnp.cumprod( + jnp.c_[ + jnp.ones((jnp.shape(g)[0], 1)), + jnp.cos(X[:, : m - 1] * jnp.pi / 2), + ], + axis=1, + ) + ) + * jnp.c_[ + jnp.ones((jnp.shape(g)[0], 1)), jnp.sin(X[:, m - 2 :: -1] * jnp.pi / 2) + ] + ) return f, state def pf(self, state: chex.PyTreeDef): n = self.ref_num * self.m - f = jnp.vstack((jnp.hstack(((jnp.arange(0, 1, 1. / (n - 1))), 1.)), - jnp.hstack(((jnp.arange(1, 0, -1. / (n - 1))), 0.)))).T - f /= jnp.tile(jnp.sqrt(jnp.sum(f ** 2, axis=1, - keepdims=True)), (1, jnp.shape(f)[1])) + f = jnp.vstack( + ( + jnp.hstack(((jnp.arange(0, 1, 1.0 / (n - 1))), 1.0)), + jnp.hstack(((jnp.arange(1, 0, -1.0 / (n - 1))), 0.0)), + ) + ).T + f /= jnp.tile( + jnp.sqrt(jnp.sum(f**2, axis=1, keepdims=True)), (1, jnp.shape(f)[1]) + ) for i in range(self.m - 2): f = jnp.c_[f[:, 0], f] - f = f / jnp.sqrt(2) * jnp.tile(jnp.hstack((self.m - 2, - jnp.arange(self.m - 2, -1, -1))), (jnp.shape(f)[0], 1)) + f = ( + f + / jnp.sqrt(2) + * jnp.tile( + jnp.hstack((self.m - 2, jnp.arange(self.m - 2, -1, -1))), + (jnp.shape(f)[0], 1), + ) + ) return f, state @@ -173,30 +241,50 @@ def __init__(self, d=None, m=None, ref_num=1000): def evaluate(self, state: chex.PyTreeDef, X: chex.Array): m = self.m - g = jnp.sum((X[:, m - 1:] ** 0.1), axis=1, keepdims=True) + g = jnp.sum((X[:, m - 1 :] ** 0.1), axis=1, keepdims=True) temp = jnp.tile(g, (1, m - 2)) - X = X.at[:, 1:m - 1].set((1 + 2 * temp * - X[:, 1:m - 1]) / (2 + 2 * temp)) - - f = jnp.tile(1 + g, (1, m)) * jnp.fliplr(jnp.cumprod(jnp.c_[jnp.ones((jnp.shape(g)[0], 1)), - jnp.cos(X[:, :m - 1] * jnp.pi / 2)], - axis=1)) * \ - jnp.c_[jnp.ones((jnp.shape(g)[0], 1)), jnp.sin( - X[:, m - 2::-1] * jnp.pi / 2)] + X = X.at[:, 1 : m - 1].set((1 + 2 * temp * X[:, 1 : m - 1]) / (2 + 2 * temp)) + + f = ( + jnp.tile(1 + g, (1, m)) + * jnp.fliplr( + jnp.cumprod( + jnp.c_[ + jnp.ones((jnp.shape(g)[0], 1)), + jnp.cos(X[:, : m - 1] * jnp.pi / 2), + ], + axis=1, + ) + ) + * jnp.c_[ + jnp.ones((jnp.shape(g)[0], 1)), jnp.sin(X[:, m - 2 :: -1] * jnp.pi / 2) + ] + ) return f, state def pf(self, state: chex.PyTreeDef): n = self.ref_num * self.m - f = jnp.vstack((jnp.hstack(((jnp.arange(0, 1, 1. / (n - 1))), 1.)), - jnp.hstack(((jnp.arange(1, 0, -1. / (n - 1))), 0.)))).T - f /= jnp.tile(jnp.sqrt(jnp.sum(f ** 2, axis=1, - keepdims=True)), (1, jnp.shape(f)[1])) + f = jnp.vstack( + ( + jnp.hstack(((jnp.arange(0, 1, 1.0 / (n - 1))), 1.0)), + jnp.hstack(((jnp.arange(1, 0, -1.0 / (n - 1))), 0.0)), + ) + ).T + f /= jnp.tile( + jnp.sqrt(jnp.sum(f**2, axis=1, keepdims=True)), (1, jnp.shape(f)[1]) + ) for i in range(self.m - 2): f = jnp.c_[f[:, 0], f] - f = f / jnp.sqrt(2) * jnp.tile(jnp.hstack((self.m - 2, - jnp.arange(self.m - 2, -1, -1))), (jnp.shape(f)[0], 1)) + f = ( + f + / jnp.sqrt(2) + * jnp.tile( + jnp.hstack((self.m - 2, jnp.arange(self.m - 2, -1, -1))), + (jnp.shape(f)[0], 1), + ) + ) return f, state @@ -216,11 +304,21 @@ def evaluate(self, state: chex.PyTreeDef, X: chex.Array): n, d = jnp.shape(X) m = self.m f = jnp.zeros((n, m)) - g = 1 + 9 * jnp.mean(X[:, m - 1:], axis=1, keepdims=True) - f = f.at[:, :m - 1].set(X[:, :m - 1]) - f = f.at[:, m - 1:].set((1 + g) * (m - jnp.sum(f[:, :m - 1] / (1 + jnp.tile(g, (1, m - 1))) - * (1 + jnp.sin(3 * jnp.pi * f[:, :m - 1])), axis=1, - keepdims=True))) + g = 1 + 9 * jnp.mean(X[:, m - 1 :], axis=1, keepdims=True) + f = f.at[:, : m - 1].set(X[:, : m - 1]) + f = f.at[:, m - 1 :].set( + (1 + g) + * ( + m + - jnp.sum( + f[:, : m - 1] + / (1 + jnp.tile(g, (1, m - 1))) + * (1 + jnp.sin(3 * jnp.pi * f[:, : m - 1])), + axis=1, + keepdims=True, + ) + ) + ) return f, state def pf(self, state: chex.PyTreeDef): diff --git a/tests/test_multi_objective_algorithms.py b/tests/test_multi_objective_algorithms.py index b78808b6..dcdda6a9 100644 --- a/tests/test_multi_objective_algorithms.py +++ b/tests/test_multi_objective_algorithms.py @@ -160,3 +160,33 @@ def test_gde3(): pop_size=POP_SIZE, ) run_moea(algorithm) + + +def test_sra(): + algorithm = algorithms.SRA( + lb=jnp.full(shape=(12,), fill_value=0), + ub=jnp.full(shape=(12,), fill_value=1), + n_objs=3, + pop_size=100, + ) + run_moea(algorithm) + + +def test_tdea(): + algorithm = algorithms.TDEA( + lb=jnp.full(shape=(12,), fill_value=0), + ub=jnp.full(shape=(12,), fill_value=1), + n_objs=3, + pop_size=100, + ) + run_moea(algorithm) + + +def test_bce_ibea(): + algorithm = algorithms.BCEIBEA( + lb=jnp.full(shape=(12,), fill_value=0), + ub=jnp.full(shape=(12,), fill_value=1), + n_objs=3, + pop_size=100, + ) + run_moea(algorithm)