Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async persistance #488

Merged
merged 3 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 254 additions & 48 deletions burr/core/application.py

Large diffs are not rendered by default.

27 changes: 20 additions & 7 deletions burr/core/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses
import hashlib
import inspect
import logging
from typing import (
Any,
AsyncGenerator,
Expand All @@ -19,7 +20,7 @@

from burr.common import async_utils
from burr.common.async_utils import SyncOrAsyncGenerator, SyncOrAsyncGeneratorOrItemOrList
from burr.core import Action, Application, ApplicationBuilder, ApplicationContext, Graph, State
from burr.core import Action, ApplicationBuilder, ApplicationContext, Graph, State
from burr.core.action import SingleStepAction
from burr.core.application import ApplicationIdentifiers
from burr.core.graph import GraphBuilder
Expand All @@ -28,6 +29,7 @@
from burr.tracking.base import TrackingClient

SubgraphType = Union[Action, Callable, "RunnableGraph"]
logger = logging.getLogger(__name__)


@dataclasses.dataclass
Expand Down Expand Up @@ -70,7 +72,7 @@ def create(from_: SubgraphType) -> "RunnableGraph":

@dataclasses.dataclass
class SubGraphTask:
"""Task to run a subgraph. Has runtime-spefici information, like inputs, state, and
"""Task to run a subgraph. Has runtime-specific information, like inputs, state, and
the application ID. This is the lower-level component -- the user will only directly interact
with this if they use the TaskBasedParallelAction interface, which produces a generator of these.
"""
Expand All @@ -84,7 +86,7 @@ class SubGraphTask:
state_persister: Optional[BaseStateSaver] = None
state_initializer: Optional[BaseStateLoader] = None

def _create_app(self, parent_context: ApplicationIdentifiers) -> Application:
def _create_app_builder(self, parent_context: ApplicationIdentifiers) -> ApplicationBuilder:
builder = (
ApplicationBuilder()
.with_graph(self.graph.graph)
Expand All @@ -101,6 +103,7 @@ def _create_app(self, parent_context: ApplicationIdentifiers) -> Application:
)
if self.tracker is not None:
builder = builder.with_tracker(self.tracker) # TODO -- move this into the adapter

# In this case we want to persist the state for the app
if self.state_persister is not None:
builder = builder.with_state_persister(self.state_persister)
Expand All @@ -119,22 +122,32 @@ def _create_app(self, parent_context: ApplicationIdentifiers) -> Application:
else:
builder = builder.with_entrypoint(self.graph.entrypoint).with_state(self.state)

return builder.build()
return builder

def run(
self,
parent_context: ApplicationContext,
) -> State:
"""Runs the task -- this simply executes it b y instantiating a sub-application"""
app = self._create_app(parent_context)
"""Runs the task -- this simply executes it by instantiating a sub-application"""
app = self._create_app_builder(parent_context).build()
action, result, state = app.run(
halt_after=self.graph.halt_after,
inputs={key: value for key, value in self.inputs.items() if not key.startswith("__")},
)
return state

async def arun(self, parent_context: ApplicationContext):
app = self._create_app(parent_context)
# Here for backwards compatibility, not ideal
if (self.state_initializer is not None and not self.state_initializer.is_async()) or (
self.state_persister is not None and not self.state_persister.is_async()
):
logger.warning(
"You are using sync persisters for an async application which is not optimal. "
"Consider switching to an async persister implementation. We will make this an error soon."
)
app = self._create_app_builder(parent_context).build()
else:
app = await self._create_app_builder(parent_context).abuild()
action, result, state = await app.arun(
halt_after=self.graph.halt_after,
inputs={key: value for key, value in self.inputs.items() if not key.startswith("__")},
Expand Down
Loading
Loading