Skip to content

Commit

Permalink
dev: add parallel_init and parallel_step as a sugar
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Sep 9, 2024
1 parent 1bee54c commit ce5e6cd
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
24 changes: 24 additions & 0 deletions src/evox/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,30 @@ def init(self, key: jax.Array = None, no_state: bool = False) -> State:
state, _node_id = self._recursive_init(key, 0, None, no_state)
return state

def parallel_init(
self, key: jax.Array, num_copies: int, no_state: bool = False
) -> Tuple[State, int]:
"""Initialize multiple copies of this module in parallel
This method should not be overwritten.
Parameters
----------
key
A PRNGKey.
num_copies
The number of copies to be initialized
no_state
Whether to skip the state initialization
Returns
-------
Tuple[State, int]
The state of this module and all submodules combined, and the last node_id
"""
subkeys = jax.random.split(key, num_copies)
return jax.vmap(self.init, in_axes=(0, None))(subkeys, no_state)

@classmethod
def stack(cls, stateful_objs, axis=0):
for obj in stateful_objs:
Expand Down
34 changes: 25 additions & 9 deletions src/evox/workflows/std_workflow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Callable, Sequence
from typing import Optional, Union
from typing import Optional, Tuple, Union
from functools import partial
import warnings

import jax
Expand All @@ -23,6 +24,13 @@
from evox.utils import parse_opt_direction


def _leftover_callbacks_warning(method_name):
warnings.warn(
f"`{method_name}` is called with a state that has leftover callbacks. "
"Did you forget to call `execute_callbacks`?"
)


@dataclass
class StdWorkflowState:
generation: int
Expand Down Expand Up @@ -189,11 +197,12 @@ def _step(self, state):

return state

self._step = partial(_step, self)
self._parallel_step = jax.vmap(self._step)
if self.jit_step:
# the first argument is self, which should be static
self._step = jax.jit(_step, static_argnums=(0,))
else:
self._step = _step
self._step = jax.jit(self._step)
self._parallel_step = jax.jit(self._parallel_step)

# by default, use the first device
self.devices = jax.local_devices()[:1]
Expand Down Expand Up @@ -292,12 +301,19 @@ def setup(self, key):

def step(self, state):
if self.auto_exec_callbacks and state._callbacks:
warnings.warn(
"`step` is called with a state that has leftover callbacks."
"Did you forget to call `execute_callbacks`?"
)
_leftover_callbacks_warning("step")

state = self._step(state)

if self.auto_exec_callbacks:
state = state.execute_callbacks(state)
return state

def parallel_step(self, state):
if self.auto_exec_callbacks and state._callbacks:
_leftover_callbacks_warning("parallel_step")

state = self._step(self, state)
state = self._parallel_step(state)

if self.auto_exec_callbacks:
state = state.execute_callbacks(state)
Expand Down

0 comments on commit ce5e6cd

Please sign in to comment.