Skip to content

Commit

Permalink
Add async application builder -- tested
Browse files Browse the repository at this point in the history
  • Loading branch information
jernejfrank committed Dec 27, 2024
1 parent 64f9b66 commit 05c6d09
Show file tree
Hide file tree
Showing 4 changed files with 411 additions and 10 deletions.
2 changes: 2 additions & 0 deletions burr/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
ApplicationBuilder,
ApplicationContext,
ApplicationGraph,
AsyncApplicationBuilder,
)
from burr.core.graph import Graph, GraphBuilder
from burr.core.state import State
Expand All @@ -13,6 +14,7 @@
"Action",
"Application",
"ApplicationBuilder",
"AsyncApplicationBuilder",
"ApplicationGraph",
"ApplicationContext",
"Condition",
Expand Down
222 changes: 214 additions & 8 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@
StreamingResultContainer,
)
from burr.core.graph import Graph, GraphBuilder
from burr.core.persistence import BaseStateLoader, BaseStateSaver
from burr.core.persistence import (
AsyncBaseStateLoader,
AsyncBaseStateSaver,
BaseStateLoader,
BaseStateSaver,
)
from burr.core.state import State
from burr.core.typing import ActionSchema, DictBasedTypingSystem, TypingSystem
from burr.core.validation import BASE_ERROR_MESSAGE
Expand Down Expand Up @@ -87,7 +92,9 @@ def _raise_fn_return_validation_error(output: Any, action_name: str):


def _adjust_single_step_output(
output: Union[State, Tuple[dict, State]], action_name: str, action_schema: ActionSchema
output: Union[State, Tuple[dict, State]],
action_name: str,
action_schema: ActionSchema,
):
"""Adjusts the output of a single step action to be a tuple of (result, state) or just state"""

Expand Down Expand Up @@ -1037,7 +1044,10 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True
)
else:
result = await _arun_function(
next_action, self._state, inputs=action_inputs, name=next_action.name
next_action,
self._state,
inputs=action_inputs,
name=next_action.name,
)
new_state = _run_reducer(next_action, self._state, result, next_action.name)
new_state = self._update_internal_state_value(new_state, next_action)
Expand Down Expand Up @@ -1174,6 +1184,22 @@ def iterate(
:return: Each iteration returns the result of running `step`. This generator also returns a tuple of
[action, result, current state]
"""
# This is a gentle warning for existing users
if self._adapter_set.async_hooks:
logger.warning(
"There are asynchronous hooks present in the application that will be ignored. "
"Please use .aiterate or .arun methods to have them executed. "
f"The application has following asynchronous hooks: {self._adapter_set.async_hooks} "
)

# Seems fair to raise this if everything is async but the app execution
if self._adapter_set.async_hooks and isinstance(self._builder, AsyncApplicationBuilder):
raise ValueError(
"The application was build using the AsyncApplicationBuilder and has "
"async hooks present, which need to be executed in an asynchronous run. "
"Please use the .aiterate or .arun methods to run the application."
)

halt_before, halt_after, inputs = self._clean_iterate_params(
halt_before, halt_after, inputs
)
Expand Down Expand Up @@ -1466,8 +1492,9 @@ def callback(
sequence_id=self.sequence_id,
exception=None,
)
out = action, StreamingResultContainer.pass_through(
results=result, final_state=state
out = (
action,
StreamingResultContainer.pass_through(results=result, final_state=state),
)
call_execute_method_wrapper.call_post(self, None)
return out
Expand Down Expand Up @@ -2150,7 +2177,9 @@ def with_entrypoint(self, action: str) -> "ApplicationBuilder[StateType]":
return self

def with_actions(
self, *action_list: Union[Action, Callable], **action_dict: Union[Action, Callable]
self,
*action_list: Union[Action, Callable],
**action_dict: Union[Action, Callable],
) -> "ApplicationBuilder[StateType]":
"""Adds an action to the application. The actions are granted names (using the with_name)
method post-adding, using the kw argument. If it already has a name (or you wish to use the function name, raw, and
Expand All @@ -2168,7 +2197,8 @@ def with_actions(
def with_transitions(
self,
*transitions: Union[
Tuple[Union[str, list[str]], str], Tuple[Union[str, list[str]], str, Condition]
Tuple[Union[str, list[str]], str],
Tuple[Union[str, list[str]], str, Condition],
],
) -> "ApplicationBuilder[StateType]":
"""Adds transitions to the application. Transitions are specified as tuples of either:
Expand Down Expand Up @@ -2249,7 +2279,7 @@ def with_tracker(

def initialize_from(
self,
initializer: BaseStateLoader,
initializer: Union[BaseStateLoader, AsyncBaseStateLoader],
resume_at_next_action: bool,
default_state: dict,
default_entrypoint: str,
Expand Down Expand Up @@ -2484,3 +2514,179 @@ def build(self) -> Application[StateType]:
state_persister=self.state_persister,
state_initializer=self.state_initializer,
)


class AsyncApplicationBuilder(ApplicationBuilder):
    """Application builder whose ``build`` is a coroutine.

    The configuration methods stay synchronous so they can still be chained.
    Anything that has to be awaited (checking/registering the state persister,
    loading persisted state) is deferred to ``build``.
    """

    def __init__(self):
        super().__init__()

    def is_async(self) -> bool:
        """Return ``True``: this builder must be built with ``await builder.build()``."""
        return True

    def with_state_persister(
        self,
        persister: Union[AsyncBaseStateSaver, LifecycleAdapter],
        on_every: str = "step",
    ) -> "ApplicationBuilder[StateType]":
        """Records the state persister to attach when the application is built.

        This is one of two options:

        1. [normal mode] An ``AsyncBaseStateSaver`` object -- wrapped in a
           ``PersisterHookAsync`` at build time.
        2. [power-user-mode] A lifecycle adapter -- a custom class that you use to
           save state; it is registered as-is at build time.

        Nothing is validated or registered here beyond ``on_every``; the persister is
        only stored, because checking that it is initialized requires awaiting it.

        :param persister: The persister to add
        :param on_every: The interval to persist state. Currently only "step" is supported.
        :return: The application builder for future chaining.
        :raises ValueError: If ``on_every`` is anything other than "step".
        """
        if on_every != "step":
            raise ValueError(f"on_every {on_every} not supported")

        self.state_persister = persister  # track for later
        return self

    async def __with_async_state_persister(self):
        """Registers ``self.state_persister`` as a lifecycle adapter.

        This is the awaitable half of ``with_state_persister``: it lives in its own
        coroutine so that ``build`` can await it while the builder methods themselves
        stay synchronous and chainable.

        :raises RuntimeError: If the async persister reports that it is not initialized.
        :raises ValueError: If a synchronous ``BaseStateSaver`` was supplied.
        """
        persister = self.state_persister

        if isinstance(persister, persistence.AsyncBaseStateSaver):
            # Persisters that do not implement is_initialized() are assumed to be ready.
            try:
                initialized = await persister.is_initialized()
            except NotImplementedError:
                initialized = True
            if not initialized:
                raise RuntimeError(
                    "RuntimeError: Uninitialized persister. Make sure to call .initialize() before passing it to "
                    "the ApplicationBuilder."
                )
            self.lifecycle_adapters.append(persistence.PersisterHookAsync(persister))
            return

        if isinstance(persister, persistence.BaseStateSaver):
            # A synchronous saver would otherwise be registered as a raw lifecycle adapter
            # and never be wrapped in a persister hook, so fail loudly instead.
            raise ValueError(
                BASE_ERROR_MESSAGE
                + f"{persister.__class__.__name__} is a synchronous state persister and cannot be used "
                "with the AsyncApplicationBuilder. Please pass an AsyncBaseStateSaver implementation, "
                "or use the ApplicationBuilder for synchronous persisters."
            )

        # power-user mode: the object is already a lifecycle adapter
        self.lifecycle_adapters.append(persister)

    async def _load_from_persister(self):
        """Loads from the set persister and into this current object.

        Mutates:
            - self.state
            - self.sequence_id
            - maybe self.start

        :raises ValueError: If forking onto the same ``app_id``, if the initializer is a
            synchronous loader, or if the loaded record carries a ``None`` state.
        """
        if self.fork_from_app_id is not None:
            if self.app_id == self.fork_from_app_id:
                raise ValueError(
                    BASE_ERROR_MESSAGE + "Cannot fork and save to the same app_id. "
                    "Please update the app_id passed in via with_identifiers(), "
                    "or don't pass in a fork_from_app_id value to `initialize_from()`."
                )
            _partition_key = self.fork_from_partition_key
            _app_id = self.fork_from_app_id
            _sequence_id = self.fork_from_sequence_id
        else:
            # only use the with_identifier values if we're not forking from a previous app
            _partition_key = self.partition_key
            _app_id = self.app_id
            _sequence_id = self.sequence_id

        initializer = self.state_initializer
        if isinstance(initializer, persistence.BaseStateLoader) and not isinstance(
            initializer, persistence.AsyncBaseStateLoader
        ):
            # Awaiting the result of a synchronous load() fails with an opaque TypeError;
            # surface the actual misconfiguration instead.
            raise ValueError(
                BASE_ERROR_MESSAGE
                + f"{initializer.__class__.__name__} is a synchronous state loader and cannot be used "
                "with the AsyncApplicationBuilder. Please pass an AsyncBaseStateLoader implementation, "
                "or use the ApplicationBuilder for synchronous loaders."
            )

        # load state from persister
        load_result = await self.state_initializer.load(_partition_key, _app_id, _sequence_id)
        if load_result is None:
            if self.fork_from_app_id is not None:
                logger.warning(
                    f"{self.state_initializer.__class__.__name__} returned None while trying to fork from: "
                    f"partition_key:{_partition_key}, app_id:{_app_id}, "
                    f"sequence_id:{_sequence_id}. "
                    "You explicitly requested to fork from a prior application run, but it does not exist. "
                    "Defaulting to state defaults instead."
                )
            # there was nothing to load -- use default state
            self.state = self.state.update(**self.default_state)
            self.sequence_id = None  # has to start at None
        else:
            self.loaded_from_fork = True
            if load_result["state"] is None:
                raise ValueError(
                    BASE_ERROR_MESSAGE
                    + f"Error: {self.state_initializer.__class__.__name__} returned {load_result} for "
                    f"partition_key:{_partition_key}, app_id:{_app_id}, "
                    f"sequence_id:{_sequence_id}, "
                    "but value for state was None! This is not allowed. Please return just None in this case, "
                    "or double check that persisted state can never be a None value."
                )
            # TODO: capture parent app ID relationship & wire it through
            # there was something
            last_position = load_result["position"]
            self.state = load_result["state"]
            self.sequence_id = load_result["sequence_id"]
            status = load_result["status"]
            if self.resume_at_next_action:
                # if we're supposed to resume where we saved from
                if status == "completed":
                    # completed means we set prior step to current to go to next action
                    self.state = self.state.update(**{PRIOR_STEP: last_position})
                else:
                    # else we failed we just start at that node
                    self.start = last_position
                    self.reset_to_entrypoint()
            else:
                # self.start is already set to the default. We don't need to do anything.
                pass

    @telemetry.capture_function_usage
    async def build(self) -> Application[StateType]:
        """Builds the application.

        Must be awaited. In order, this registers the state persister (if one was
        set), loads persisted state through the state initializer (if one was set),
        builds and validates the graph, and constructs the ``Application``.

        :return: The application object
        """

        # If we make state persister async we cannot method chain the builder. This delays the persister
        # init so that we can coroutine chain within the build method
        if self.state_persister:
            await self.__with_async_state_persister()

        _validate_app_id(self.app_id)
        if self.state is None:
            self.state = State()
        if self.state_initializer:
            # sets state, sequence_id, and maybe start
            await self._load_from_persister()
        graph = self._get_built_graph()
        _validate_start(self.start, {action.name for action in graph.actions})
        typing_system: TypingSystem[StateType] = (
            self.typing_system if self.typing_system is not None else DictBasedTypingSystem()
        )  # type: ignore
        self.state = self.state.with_typing_system(typing_system=typing_system)
        return Application(
            graph=graph,
            state=self.state,
            uid=self.app_id,
            partition_key=self.partition_key,
            sequence_id=self.sequence_id,
            entrypoint=self.start,
            adapter_set=LifecycleAdapterSet(*self.lifecycle_adapters),
            builder=self,
            fork_parent_pointer=(
                burr_types.ParentPointer(
                    app_id=self.fork_from_app_id,
                    partition_key=self.fork_from_partition_key,
                    sequence_id=self.fork_from_sequence_id,
                )
                if self.loaded_from_fork
                else None
            ),
            tracker=self.tracker,
            spawning_parent_pointer=(
                burr_types.ParentPointer(
                    app_id=self.spawn_from_app_id,
                    partition_key=self.spawn_from_partition_key,
                    sequence_id=self.spawn_from_sequence_id,
                )
                if self.spawn_from_app_id is not None
                else None
            ),
            parallel_executor_factory=self.parallel_executor_factory,
            state_persister=self.state_persister,
            state_initializer=self.state_initializer,
        )
3 changes: 3 additions & 0 deletions burr/core/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,11 @@ class AsyncDevNullPersister(AsyncBaseStatePersister):
async def load(
    self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs
) -> Optional[PersistedStateData]:
    """Always return ``None``; every argument is ignored."""
    return None

async def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
    """Always return an empty list; every argument is ignored."""
    return []

async def save(
Expand All @@ -282,6 +284,7 @@ async def save(
status: Literal["completed", "failed"],
**kwargs,
):
# print("I saved something.")
return


Expand Down
Loading

0 comments on commit 05c6d09

Please sign in to comment.