Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Dec 23, 2024
1 parent ed2f26b commit c71f34c
Show file tree
Hide file tree
Showing 7 changed files with 538 additions and 38 deletions.
10 changes: 10 additions & 0 deletions burr/core/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,16 @@ def __del__(self):
# closes connection at end when things are being shutdown.
self.connection.close()

def __getstate__(self):
return {key: value for key, value in self.__dict__.items() if key != "connection"}

def __setstate__(self, state):
for key, value in state.items():
setattr(self, key, value)
self.connection = sqlite3.connect(
self.db_path, **self._connect_kwargs if self._connect_kwargs is not None else {}
)


class InMemoryPersister(BaseStatePersister):
"""In-memory persister for testing purposes. This is not recommended for production use."""
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/persister.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Supported Implementations

Currently we support the following, although we highly recommend you contribute your own! We will be adding more shortly.

.. autoclass:: burr.core.persistence.SQLLitePersister
.. autoclass:: burr.core.persistence.SQLitePersister
:members:

.. automethod:: __init__
Expand Down
62 changes: 34 additions & 28 deletions examples/ray/application.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

import openai
import ray
Expand All @@ -7,6 +7,7 @@
from burr.core import Application, ApplicationBuilder, Condition, GraphBuilder, State, action
from burr.core.application import ApplicationContext
from burr.core.parallelism import MapStates, RunnableGraph, SubgraphType
from burr.core.persistence import SQLitePersister
from burr.integrations.ray import RayExecutor


Expand Down Expand Up @@ -75,7 +76,6 @@ def edit(state: State) -> Tuple[dict, State]:
Here is the current draft of the poem: "{current_draft}".
Provide detailed feedback to improve the poem. If the poem is already excellent and needs no changes, simply respond with an empty string.
"""

feedback = _query_llm(prompt)

return {"feedback": feedback}, state.update(feedback=feedback)
Expand All @@ -95,23 +95,17 @@ def final_draft(state: State) -> Tuple[dict, State]:
"poem_subject",
],
)
def user_input(
state: State, max_drafts: int, poem_types: List[str], poem_subject: str
) -> Tuple[dict, State]:
def user_input(state: State, max_drafts: int, poem_types: List[str], poem_subject: str) -> State:
"""Collects user input for the poem generation process."""
return {
"max_drafts": max_drafts,
"poem_types": poem_types,
"poem_subject": poem_subject,
}, state.update(max_drafts=max_drafts, poem_types=poem_types, poem_subject=poem_subject)
return state.update(max_drafts=max_drafts, poem_types=poem_types, poem_subject=poem_subject)


class GenerateAllPoems(MapStates):
def states(
self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
) -> SyncOrAsyncGenerator[State]:
for poem_type in state["poem_types"]:
yield state.update(current_draft=None, poem_type=poem_type, feedback=None)
yield state.update(current_draft=None, poem_type=poem_type, feedback=[], num_drafts=0)

def action(self, state: State, inputs: Dict[str, Any]) -> SubgraphType:
graph = (
Expand All @@ -122,7 +116,7 @@ def action(self, state: State, inputs: Dict[str, Any]) -> SubgraphType:
final_draft,
)
.with_transitions(
("write", "edit", Condition.expr(f"num_drafts < {inputs['max_drafts']}")),
("write", "edit", Condition.expr(f"num_drafts < {state['max_drafts']}")),
("write", "final_draft"),
("edit", "final_draft", Condition.expr("len(feedback) == 0")),
("edit", "write"),
Expand All @@ -142,11 +136,7 @@ def writes(self) -> list[str]:

@property
def reads(self) -> list[str]:
return ["poem_types", "poem_subject"]

@property
def inputs(self) -> list[str]:
return super().inputs + ["max_drafts"]
return ["poem_types", "poem_subject", "max_drafts"]


@action(reads=["proposals", "poem_types"], writes=["final_results"])
Expand All @@ -159,31 +149,47 @@ def final_results(state: State) -> Tuple[dict, State]:
return {"final_results": final_results}, state.update(final_results=final_results)


def application() -> Application:
ray.init()
return (
def application_multithreaded() -> Application:
app = (
ApplicationBuilder()
.with_actions(user_input, final_results, generate_all_poems=GenerateAllPoems())
.with_transitions(
("user_input", "generate_all_poems"),
("generate_all_poems", "final_results"),
)
.with_tracker(project="test:parallelism_poem_generation_ray")
.with_tracker(project="demo:parallel_agents")
.with_entrypoint("user_input")
.build()
)
return app


def application(app_id: Optional[str] = None) -> Application:
persister = SQLitePersister(db_path="./db")
persister.initialize()
app = (
ApplicationBuilder()
.with_actions(user_input, final_results, generate_all_poems=GenerateAllPoems())
.with_transitions(
("user_input", "generate_all_poems"),
("generate_all_poems", "final_results"),
)
.with_tracker(project="demo:parallel_agents_fault_tolerance")
.with_parallel_executor(RayExecutor)
.with_state_persister(persister)
.initialize_from(
persister, resume_at_next_action=True, default_state={}, default_entrypoint="user_input"
)
.with_identifiers(app_id=app_id)
.build()
)
return app


if __name__ == "__main__":
ray.init()
app = application()
app.visualize(output_file_path="statemachine", format="png")
# _create_sub_application(
# 2,
# app.context,
# "sonnet",
# "state machines",
# ).visualize(output_file_path="statemachine_sub", format="png")
app_id = app.uid
act, _, state = app.run(
halt_after=["final_results"],
inputs={
Expand Down
Loading

0 comments on commit c71f34c

Please sign in to comment.