diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py index 6c83326d..41332aaf 100644 --- a/burr/core/parallelism.py +++ b/burr/core/parallelism.py @@ -19,7 +19,15 @@ from burr.common import async_utils from burr.common.async_utils import SyncOrAsyncGenerator, SyncOrAsyncGeneratorOrItemOrList -from burr.core import Action, Application, ApplicationBuilder, ApplicationContext, Graph, State +from burr.core import ( + Action, + Application, + ApplicationBuilder, + ApplicationContext, + AsyncApplicationBuilder, + Graph, + State, +) from burr.core.action import SingleStepAction from burr.core.application import ApplicationIdentifiers from burr.core.graph import GraphBuilder @@ -70,7 +78,7 @@ def create(from_: SubgraphType) -> "RunnableGraph": @dataclasses.dataclass class SubGraphTask: - """Task to run a subgraph. Has runtime-spefici information, like inputs, state, and + """Task to run a subgraph. Has runtime-specific information, like inputs, state, and the application ID. This is the lower-level component -- the user will only directly interact with this if they use the TaskBasedParallelAction interface, which produces a generator of these. """ @@ -133,8 +141,45 @@ def run( ) return state + async def _create_async_app(self, parent_context: ApplicationIdentifiers) -> Application: + builder = ( + AsyncApplicationBuilder() + .with_graph(self.graph.graph) + .with_spawning_parent( + app_id=parent_context.app_id, + sequence_id=parent_context.sequence_id, + partition_key=parent_context.partition_key, + ) + # TODO -- handle persistence... + .with_identifiers( + app_id=self.application_id, + partition_key=parent_context.partition_key, # cascade the partition key + ) + ) + if self.tracker is not None: + builder = builder.with_tracker(self.tracker) # TODO -- move this into the adapter + # In this case we want to persist the state for the app + if self.state_persister is not None: + builder = builder.with_state_persister(self.state_persister) + # In this case we want to initialize from it + # We're going to use default settings (initialize from the latest + # TODO -- consider if there's a case in which we want to initialize + # in a custom manner + # if state_initializer is not None and self.cascade_state_initializer: + if self.state_initializer is not None: + builder = builder.initialize_from( + self.state_initializer, + default_state=self.state.get_all(), # TODO _- ensure that any hidden variables aren't used... + default_entrypoint=self.graph.entrypoint, + resume_at_next_action=True, + ) + else: + builder = builder.with_entrypoint(self.graph.entrypoint).with_state(self.state) + + return await builder.build() + async def arun(self, parent_context: ApplicationContext): - app = self._create_app(parent_context) + app = await self._create_async_app(parent_context) action, result, state = await app.arun( halt_after=self.graph.halt_after, inputs={key: value for key, value in self.inputs.items() if not key.startswith("__")},