Skip to content

Commit

Permalink
Add async builder to SubGraphTask for parallelism
Browse files Browse the repository at this point in the history
  • Loading branch information
jernejfrank committed Dec 27, 2024
1 parent 05c6d09 commit 9e233c8
Showing 1 changed file with 48 additions and 3 deletions.
51 changes: 48 additions & 3 deletions burr/core/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@

from burr.common import async_utils
from burr.common.async_utils import SyncOrAsyncGenerator, SyncOrAsyncGeneratorOrItemOrList
from burr.core import Action, Application, ApplicationBuilder, ApplicationContext, Graph, State
from burr.core import (
Action,
Application,
ApplicationBuilder,
ApplicationContext,
AsyncApplicationBuilder,
Graph,
State,
)
from burr.core.action import SingleStepAction
from burr.core.application import ApplicationIdentifiers
from burr.core.graph import GraphBuilder
Expand Down Expand Up @@ -70,7 +78,7 @@ def create(from_: SubgraphType) -> "RunnableGraph":

@dataclasses.dataclass
class SubGraphTask:
"""Task to run a subgraph. Has runtime-spefici information, like inputs, state, and
"""Task to run a subgraph. Has runtime-specific information, like inputs, state, and
the application ID. This is the lower-level component -- the user will only directly interact
with this if they use the TaskBasedParallelAction interface, which produces a generator of these.
"""
Expand Down Expand Up @@ -133,8 +141,45 @@ def run(
)
return state

async def _create_async_app(self, parent_context: ApplicationIdentifiers) -> Application:
builder = (
AsyncApplicationBuilder()
.with_graph(self.graph.graph)
.with_spawning_parent(
app_id=parent_context.app_id,
sequence_id=parent_context.sequence_id,
partition_key=parent_context.partition_key,
)
# TODO -- handle persistence...
.with_identifiers(
app_id=self.application_id,
partition_key=parent_context.partition_key, # cascade the partition key
)
)
if self.tracker is not None:
builder = builder.with_tracker(self.tracker) # TODO -- move this into the adapter
# In this case we want to persist the state for the app
if self.state_persister is not None:
builder = builder.with_state_persister(self.state_persister)
# In this case we want to initialize from it
# We're going to use default settings (initialize from the latest
# TODO -- consider if there's a case in which we want to initialize
# in a custom manner
# if state_initializer is not None and self.cascade_state_initializer:
if self.state_initializer is not None:
builder = builder.initialize_from(
self.state_initializer,
default_state=self.state.get_all(), # TODO _- ensure that any hidden variables aren't used...
default_entrypoint=self.graph.entrypoint,
resume_at_next_action=True,
)
else:
builder = builder.with_entrypoint(self.graph.entrypoint).with_state(self.state)

return await builder.build()

async def arun(self, parent_context: ApplicationContext):
app = self._create_app(parent_context)
app = await self._create_async_app(parent_context)
action, result, state = await app.arun(
halt_after=self.graph.halt_after,
inputs={key: value for key, value in self.inputs.items() if not key.startswith("__")},
Expand Down

0 comments on commit 9e233c8

Please sign in to comment.