Skip to content

Commit

Permalink
Merge pull request #145 from EMI-Group/state_io
Browse files Browse the repository at this point in the history
State based IO
  • Loading branch information
BillHuang2001 authored Oct 23, 2024
2 parents c993546 + 7ed05a3 commit 8285b29
Show file tree
Hide file tree
Showing 24 changed files with 771 additions and 825 deletions.
58 changes: 33 additions & 25 deletions src/evox/algorithms/so/pso_variants/pso.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,30 @@
# Link: https://ieeexplore.ieee.org/document/494215
# --------------------------------------------------------------------------------------

from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp
import copy

from evox import Algorithm, State, dataclass, pytree_field
from evox.utils import *
from evox import Algorithm, State, jit_class


@jit_class
@dataclass
class PSO(Algorithm):
def __init__(
self,
lb,
ub,
pop_size,
inertia_weight=0.6,
cognitive_coefficient=2.5,
social_coefficient=0.8,
mean=None,
stdev=None,
):
self.dim = lb.shape[0]
self.lb = lb
self.ub = ub
self.pop_size = pop_size
self.w = inertia_weight
self.phi_p = cognitive_coefficient
self.phi_g = social_coefficient
self.mean = mean
self.stdev = stdev
dim: jax.Array = pytree_field(static=True, init=False)
lb: jax.Array
ub: jax.Array
pop_size: jax.Array = pytree_field(static=True)
w: jax.Array = pytree_field(default=0.6)
phi_p: jax.Array = pytree_field(default=2.5)
phi_g: jax.Array = pytree_field(default=0.8)
mean: Optional[jax.Array] = pytree_field(default=None)
stdev: Optional[jax.Array] = pytree_field(default=None)
bound_method: str = pytree_field(static=True, default="clip")

def __post_init__(self):
self.set_frozen_attr("dim", self.lb.shape[0])

def setup(self, key):
state_key, init_pop_key, init_v_key = jax.random.split(key, 3)
Expand Down Expand Up @@ -95,7 +87,23 @@ def tell(self, state, fitness):
+ self.phi_g * rg * (global_best_location - state.population)
)
population = state.population + velocity
population = jnp.clip(population, self.lb, self.ub)

if self.bound_method == "clip":
population = jnp.clip(population, self.lb, self.ub)
elif self.bound_method == "reflect":
lower_bound_violation = population < self.lb
upper_bound_violation = population > self.ub

population = jnp.where(
lower_bound_violation, 2 * self.lb - population, population
)
population = jnp.where(
upper_bound_violation, 2 * self.ub - population, population
)

velocity = jnp.where(
lower_bound_violation | upper_bound_violation, -velocity, velocity
)

return state.replace(
population=population,
Expand Down
70 changes: 60 additions & 10 deletions src/evox/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def wrapper(self, state: State, *args, **kwargs):
new_state,
)

state = state.replace_by_path(path, new_state)
state = state.replace_by_path(
path, new_state.clear_callbacks()
).prepend_closure(new_state)

if aux is None:
return state
Expand Down Expand Up @@ -148,6 +150,10 @@ class Stateful:
The ``init`` method will automatically call the ``setup`` of the current module
and recursively call ``setup`` methods of all submodules.
Currently, there are two special metadata that can be used to control the behavior of the module initialization:
- ``stack``: If set to True, the module will be initialized multiple times, and the states will be stacked together.
- ``nested``: If set to True, the a list of modules, that is [module1, module2, ...], will be iterated and initialized.
"""

def __init__(self):
Expand All @@ -174,10 +180,16 @@ def setup(self, key: jax.Array) -> State:
return State()

def _recursive_init(
self, key: jax.Array, node_id: int, module_name: str, no_state: bool
self,
key: jax.Array,
node_id: int,
module_name: str,
no_state: bool,
re_init: bool,
) -> Tuple[State, int]:
object.__setattr__(self, "_node_id", node_id)
object.__setattr__(self, "_module_name", module_name)
if not re_init:
object.__setattr__(self, "_node_id", node_id)
object.__setattr__(self, "_module_name", module_name)

if not no_state:
child_states = {}
Expand All @@ -197,6 +209,15 @@ def _recursive_init(

if isinstance(attr, Stateful):
submodules.append(SubmoduleInfo(field.name, attr, field.metadata))

# handle "nested" field
if field.metadata.get("nested", False):
for idx, nested_module in enumerate(attr):
submodules.append(
SubmoduleInfo(
field.name + str(idx), nested_module, field.metadata
)
)
else:
for attr_name in vars(self):
attr = getattr(self, attr_name)
Expand All @@ -211,24 +232,27 @@ def _recursive_init(
else:
key, subkey = jax.random.split(key)

# handle "StackAnnotation"
# handle "Stack"
# attr should be a list, or tuple of modules
if metadata.get("stack", False):
num_copies = len(attr)
subkeys = jax.random.split(subkey, num_copies)
current_node_id = node_id
_, node_id = attr._recursive_init(None, node_id + 1, attr_name, True)
_, node_id = attr._recursive_init(
None, node_id + 1, attr_name, True, re_init
)
submodule_state, _node_id = jax.vmap(
partial(
Stateful._recursive_init,
node_id=current_node_id + 1,
module_name=attr_name,
no_state=no_state,
re_init=re_init,
)
)(attr, subkeys)
else:
submodule_state, node_id = attr._recursive_init(
subkey, node_id + 1, attr_name, no_state
subkey, node_id + 1, attr_name, no_state, re_init
)

if not no_state:
Expand All @@ -246,10 +270,12 @@ def _recursive_init(

self_state._set_state_id_mut(self._node_id)._set_child_states_mut(
child_states
),
)
return self_state, node_id

def init(self, key: jax.Array = None, no_state: bool = False) -> State:
def init(
self, key: jax.Array = None, no_state: bool = False, re_init: bool = False
) -> State:
"""Initialize this module and all submodules
This method should not be overwritten.
Expand All @@ -264,9 +290,33 @@ def init(self, key: jax.Array = None, no_state: bool = False) -> State:
State
The state of this module and all submodules combined.
"""
state, _node_id = self._recursive_init(key, 0, None, no_state)
state, _node_id = self._recursive_init(key, 0, None, no_state, re_init)
return state

def parallel_init(
self, key: jax.Array, num_copies: int, no_state: bool = False
) -> Tuple[State, int]:
"""Initialize multiple copies of this module in parallel
This method should not be overwritten.
Parameters
----------
key
A PRNGKey.
num_copies
The number of copies to be initialized
no_state
Whether to skip the state initialization
Returns
-------
Tuple[State, int]
The state of this module and all submodules combined, and the last node_id
"""
subkeys = jax.random.split(key, num_copies)
return jax.vmap(self.init, in_axes=(0, None))(subkeys, no_state)

@classmethod
def stack(cls, stateful_objs, axis=0):
for obj in stateful_objs:
Expand Down
5 changes: 4 additions & 1 deletion src/evox/core/monitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
class Monitor:
from .module import *


class Monitor(Stateful):
"""Monitor base class.
Monitors are used to monitor the evolutionary process.
They contains a set of callbacks,
Expand Down
26 changes: 21 additions & 5 deletions src/evox/core/pytree_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from jax.tree_util import register_pytree_node
import copy
import dataclasses
from typing import Annotated, Any, Callable, Optional, Tuple, TypeVar, get_type_hints

from typing_extensions import (
dataclass_transform, # pytype: disable=not-supported-yet
)
from jax.tree_util import register_pytree_node
from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet

from .distributed import ShardingType

Expand All @@ -19,10 +18,26 @@ def pytree_field(
return dataclasses.field(**kwargs)


def _dataclass_set_frozen_attr(self, key, value):
object.__setattr__(self, key, value)


def _dataclass_replace(self, **kwargs):
"""Add a replace method to dataclasses.
It's different from dataclasses.replace in that it doesn't call the __init__,
instead it copies the object and sets the new values.
"""
new_obj = copy.copy(self)
for key, value in kwargs.items():
object.__setattr__(new_obj, key, value)
return new_obj


def dataclass(cls, *args, **kwargs):
"""
A dataclass decorator that also registers the dataclass as a pytree node.
"""
kwargs = {"unsafe_hash": False, "eq": False, **kwargs}
cls = dataclasses.dataclass(cls, *args, **kwargs)

field_info = []
Expand Down Expand Up @@ -78,7 +93,8 @@ def unflatten(aux_data, children):
register_pytree_node(cls, flatten, unflatten)

# Add a method to set frozen attributes after init
cls.set_frozen_attr = lambda self, key, value: object.__setattr__(self, key, value)
cls.set_frozen_attr = _dataclass_set_frozen_attr
cls.replace = _dataclass_replace
return cls


Expand Down
Loading

0 comments on commit 8285b29

Please sign in to comment.