Skip to content

Commit

Permalink
Wires through persistence to spawned children apps
Browse files Browse the repository at this point in the history
This enables us to use persisters from the parent in spawned apps.

It roughly works as follows:
1. We add a copy() method to the persisters
2. We pass hte persisters into the context input to the child
3. The child then clones the persisters/trackers
4. The child has a setting that allows for either cascading
or overwriting persisters/trackers/whatnot (cascade is the default).
5. We use the same instance of persister/loader *if* the parent is using
   the same instance

Note that this also improves some asynchronous generator stuff -- it's a
bit complicated, but it allows you to define asynchronous *or*
synchronous generators for the task-producing subclasses of parallel
actions. The type is detected at runtime, allowing flexible
implementation.
  • Loading branch information
elijahbenizzy committed Dec 4, 2024
1 parent b2eef9e commit 25c907b
Show file tree
Hide file tree
Showing 9 changed files with 837 additions and 86 deletions.
3 changes: 1 addition & 2 deletions burr/common/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
T = TypeVar("T")

GenType = TypeVar("GenType")
ReturnType = TypeVar("ReturnType")

SyncOrAsyncIterable = Union[AsyncIterable[T], List[T]]
SyncOrAsyncGenerator = Union[Generator[GenType, None, None], AsyncGenerator[GenType, None]]
Expand All @@ -31,7 +30,7 @@ async def arealize(maybe_async_generator: SyncOrAsyncGenerator[GenType]) -> List
"""Realize an async generator or async iterable to a list.
:param maybe_async_generator: async generator or async iterable
:return: list
:return: list of items -- fully realized
"""
if inspect.isasyncgen(maybe_async_generator):
out = [item async for item in maybe_async_generator]
Expand Down
18 changes: 16 additions & 2 deletions burr/common/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import abc
import dataclasses
from typing import Optional
from typing import Any, Optional

try:
from typing import Self
except ImportError:
Self = Any

# This contains commmon types

# This contains common types
# Currently the types are a little closer to the logic than we'd like
# We'll want to break them out into interfaces and put more here eventually
# This will help avoid the ugly if TYPE_CHECKING imports;
Expand All @@ -11,3 +17,11 @@ class ParentPointer:
app_id: str
partition_key: Optional[str]
sequence_id: Optional[int]


class BaseCopyable(abc.ABC):
"""Interface for copying objects. This is used internally."""

@abc.abstractmethod
def copy(self) -> "Self":
pass
48 changes: 32 additions & 16 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,14 @@ class ApplicationGraph(Graph):


@dataclasses.dataclass
class ApplicationContext(AbstractContextManager):
class ApplicationIdentifiers:
app_id: str
partition_key: Optional[str]
sequence_id: Optional[int]


@dataclasses.dataclass
class ApplicationContext(AbstractContextManager, ApplicationIdentifiers):
"""Application context. This is anything your node might need to know about the application.
Often used for recursive tracking.
Expand All @@ -508,11 +515,10 @@ def my_action(state: State, __context: ApplicationContext) -> State:
"""

app_id: str
partition_key: Optional[str]
sequence_id: Optional[int]
tracker: Optional["TrackingClient"]
parallel_executor_factory: Callable[[], Executor]
state_initializer: Optional[BaseStateLoader]
state_persister: Optional[BaseStateSaver]

@staticmethod
def get() -> Optional["ApplicationContext"]:
Expand Down Expand Up @@ -718,6 +724,8 @@ def __init__(
spawning_parent_pointer: Optional[burr_types.ParentPointer] = None,
tracker: Optional["TrackingClient"] = None,
parallel_executor_factory: Optional[Executor] = None,
state_persister: Union[BaseStateSaver, LifecycleAdapter, None] = None,
state_initializer: Union[BaseStateLoader, LifecycleAdapter, None] = None,
):
"""Instantiates an Application. This is an internal API -- use the builder!
Expand Down Expand Up @@ -767,6 +775,8 @@ def __init__(
"__context": self._context_factory,
}
self._spawning_parent_pointer = spawning_parent_pointer
self._state_initializer = state_initializer
self._state_persister = state_persister
self._adapter_set.call_all_lifecycle_hooks_sync(
"post_application_create",
state=self._state,
Expand Down Expand Up @@ -807,6 +817,8 @@ def _context_factory(self, action: Action, sequence_id: int) -> ApplicationConte
partition_key=self._partition_key,
sequence_id=sequence_id,
parallel_executor_factory=self._parallel_executor_factory,
state_initializer=self._state_initializer,
state_persister=self._state_persister,
)

def _step(
Expand Down Expand Up @@ -1937,7 +1949,7 @@ def __init__(self):
self.app_id: str = str(uuid.uuid4())
self.partition_key: Optional[str] = None
self.sequence_id: Optional[int] = None
self.initializer = None
self.state_initializer = None
self.use_entrypoint_from_save_state: Optional[bool] = None
self.default_state: Optional[dict] = None
self.fork_from_app_id: Optional[str] = None
Expand All @@ -1951,7 +1963,8 @@ def __init__(self):
self.graph_builder = None
self.prebuilt_graph = None
self.typing_system = None
self._parallel_executor_factory = None
self.parallel_executor_factory = None
self.state_persister = None

def with_identifiers(
self, app_id: str = None, partition_key: str = None, sequence_id: int = None
Expand Down Expand Up @@ -1996,7 +2009,7 @@ def with_state(
:param kwargs: Key-value pairs to set in the state
:return: The application builder for future chaining.
"""
if self.initializer is not None:
if self.state_initializer is not None:
raise ValueError(
BASE_ERROR_MESSAGE + "You cannot set state if you are loading state"
"the .initialize_from() API. Either allow the persister to set the "
Expand Down Expand Up @@ -2062,14 +2075,14 @@ def with_parallel_executor(self, executor_factory: lambda: Executor):
:param executor:
:return:
"""
if self._parallel_executor_factory is not None:
if self.parallel_executor_factory is not None:
raise ValueError(
BASE_ERROR_MESSAGE
+ "You have already set an executor. You cannot set multiple executors. Current executor is:"
f"{self._parallel_executor_factory}"
f"{self.parallel_executor_factory}"
)

self._parallel_executor_factory = executor_factory
self.parallel_executor_factory = executor_factory
return self

def _ensure_no_prebuilt_graph(self):
Expand Down Expand Up @@ -2238,7 +2251,7 @@ def initialize_from(
+ "If you set fork_from_partition_key or fork_from_sequence_id, you must also set fork_from_app_id. "
"See .initialize_from() documentation."
)
self.initializer = initializer
self.state_initializer = initializer
self.resume_at_next_action = resume_at_next_action
self.default_state = default_state
self.start = default_entrypoint
Expand Down Expand Up @@ -2278,6 +2291,7 @@ def with_state_persister(
except NotImplementedError:
pass
self.lifecycle_adapters.append(persistence.PersisterHook(persister))
self.state_persister = persister # track for later
return self

def with_spawning_parent(
Expand Down Expand Up @@ -2326,11 +2340,11 @@ def _load_from_persister(self):
_app_id = self.app_id
_sequence_id = self.sequence_id
# load state from persister
load_result = self.initializer.load(_partition_key, _app_id, _sequence_id)
load_result = self.state_initializer.load(_partition_key, _app_id, _sequence_id)
if load_result is None:
if self.fork_from_app_id is not None:
logger.warning(
f"{self.initializer.__class__.__name__} returned None while trying to fork from: "
f"{self.state_initializer.__class__.__name__} returned None while trying to fork from: "
f"partition_key:{_partition_key}, app_id:{_app_id}, "
f"sequence_id:{_sequence_id}. "
"You explicitly requested to fork from a prior application run, but it does not exist. "
Expand All @@ -2344,7 +2358,7 @@ def _load_from_persister(self):
if load_result["state"] is None:
raise ValueError(
BASE_ERROR_MESSAGE
+ f"Error: {self.initializer.__class__.__name__} returned {load_result} for "
+ f"Error: {self.state_initializer.__class__.__name__} returned {load_result} for "
f"partition_key:{_partition_key}, app_id:{_app_id}, "
f"sequence_id:{_sequence_id}, "
"but value for state was None! This is not allowed. Please return just None in this case, "
Expand Down Expand Up @@ -2395,7 +2409,7 @@ def build(self) -> Application[StateType]:
_validate_app_id(self.app_id)
if self.state is None:
self.state = State()
if self.initializer:
if self.state_initializer:
# sets state, sequence_id, and maybe start
self._load_from_persister()
graph = self._get_built_graph()
Expand Down Expand Up @@ -2432,5 +2446,7 @@ def build(self) -> Application[StateType]:
if self.spawn_from_app_id is not None
else None
),
parallel_executor_factory=self._parallel_executor_factory,
parallel_executor_factory=self.parallel_executor_factory,
state_persister=self.state_persister,
state_initializer=self.state_initializer,
)
Loading

0 comments on commit 25c907b

Please sign in to comment.