Ray integration + Parallelism blog post #483

Merged: 7 commits, Dec 23, 2024

Changes from all commits
10 changes: 10 additions & 0 deletions burr/core/persistence.py
@@ -352,6 +352,16 @@ def __del__(self):
# closes connection at end when things are being shutdown.
self.connection.close()

def __getstate__(self):
return {key: value for key, value in self.__dict__.items() if key != "connection"}

def __setstate__(self, state):
for key, value in state.items():
setattr(self, key, value)
self.connection = sqlite3.connect(
self.db_path, **self._connect_kwargs if self._connect_kwargs is not None else {}
)


class InMemoryPersister(BaseStatePersister):
"""In-memory persister for testing purposes. This is not recommended for production use."""
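These __getstate__/__setstate__ hooks make the persister picklable, which parallel execution needs when shipping it across processes: the live sqlite3 connection is dropped from the pickle and re-opened on unpickle. A minimal round-trip sketch (constructor arguments follow the example application later in this PR):

import pickle

from burr.core.persistence import SQLitePersister

persister = SQLitePersister(db_path="./db")
persister.initialize()
# the sqlite3 connection is excluded from the pickle and re-created on load
restored = pickle.loads(pickle.dumps(persister))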
38 changes: 38 additions & 0 deletions burr/integrations/ray.py
@@ -0,0 +1,38 @@
import concurrent.futures

import ray


class RayExecutor(concurrent.futures.Executor):
"""Ray parallel executor -- implementation of concurrent.futures.Executor.
Currently experimental"""

def __init__(self, shutdown_on_end: bool = False):
"""Creates a Ray executor -- remember to call ray.init() before running anything!"""
self.shutdown_on_end = shutdown_on_end

def submit(self, fn, *args, **kwargs):
"""Submits to ray -- creates a python future by calling ray.remote

:param fn: Function to submit
:param args: Args for the fn
:param kwargs: Kwargs for the fn
:return: The future for the fn
"""
if not ray.is_initialized():
raise RuntimeError("Ray is not initialized. Call ray.init() before running anything!")
ray_fn = ray.remote(fn)
object_ref = ray_fn.remote(*args, **kwargs)
future = object_ref.future()
elijahbenizzy marked this conversation as resolved.

Review comment: Ray's ObjectRef does not have a future() method. Use ray.get(object_ref) to retrieve the result instead.

return future

def shutdown(self, wait=True, **kwargs):
"""Shuts down the executor by shutting down ray

:param wait: Whether to wait -- required for hte API but not respected (yet)
:param kwargs: Keyword arguments -- not used yet
"""
if self.shutdown_on_end:
if ray.is_initialized():
ray.shutdown()
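A minimal standalone usage sketch (assuming a local Ray runtime; submit() returns a concurrent.futures-style future via Ray's ObjectRef.future() bridge, which is why the resolved comments above required no change):

import ray

from burr.integrations.ray import RayExecutor


def double(x: int) -> int:
    return x * 2


ray.init()
executor = RayExecutor(shutdown_on_end=True)
future = executor.submit(double, 21)  # double(21) runs as a Ray task
assert future.result() == 42
executor.shutdown()  # also shuts Ray down, since shutdown_on_end=True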
15 changes: 14 additions & 1 deletion burr/tracking/client.py
@@ -298,7 +298,7 @@ def load_state(
sequence_id: int = -1,
storage_dir: str = DEFAULT_STORAGE_DIR,
) -> tuple[dict, str]:
"""THis is deprecated and will be removed when we migrate over demos. Do not use! Instead use
"""This is deprecated and will be removed when we migrate over demos. Do not use! Instead use
the persistence API :py:class:`initialize_from <burr.core.application.ApplicationBuilder.initialize_from>`
to load state.

@@ -360,6 +360,18 @@ def _ensure_dir_structure(self, app_id: str):
logger.info(f"Creating application directory: {application_path}")
os.makedirs(application_path)

def __setstate__(self, state):
self.__dict__.update(state)

def __getstate__(self):
out = {
key: value for key, value in self.__dict__.items() if key != "f"
} # the file we don't want to serialize
# Note that this will only work if we also call post_application_create,
# which re-opens the file. For now that's OK, as that's the only place we need it --
# if we want more distribution later we'll have to serialize the file itself
out["f"] = None
return out

def post_application_create(
self,
*,
@@ -378,6 +390,7 @@ def post_application_create(
encoding="utf-8",
errors="replace",
)

graph_path = os.path.join(self.storage_dir, app_id, self.GRAPH_FILENAME)
if os.path.exists(graph_path):
logger.info(f"Graph already exists at {graph_path}. Not overwriting.")
1 change: 1 addition & 0 deletions docs/reference/integrations/index.rst
@@ -14,3 +14,4 @@ Integrations -- we will be adding more
langchain
pydantic
haystack
ray
8 changes: 8 additions & 0 deletions docs/reference/integrations/ray.rst
@@ -0,0 +1,8 @@
===
Ray
===

The Burr Ray integration allows you to run :ref:`parallel sub-applications <parallelism>` on `Ray <https://ray.io>`_.

.. autoclass:: burr.integrations.ray.RayExecutor
:members:
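
For example, an application that fans out parallel sub-applications can dispatch them on Ray by handing the executor to the builder. A sketch based on ``examples/ray/application.py`` from this PR:

.. code-block:: python

    import ray

    from burr.core import ApplicationBuilder
    from burr.integrations.ray import RayExecutor

    ray.init()  # the executor requires an initialized Ray runtime
    app = (
        ApplicationBuilder()
        # ... actions and transitions as usual ...
        .with_parallel_executor(RayExecutor)
        .build()
    )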
2 changes: 1 addition & 1 deletion docs/reference/persister.rst
@@ -37,7 +37,7 @@ Supported Implementations

Currently we support the following, although we highly recommend you contribute your own! We will be adding more shortly.

-.. autoclass:: burr.core.persistence.SQLLitePersister
+.. autoclass:: burr.core.persistence.SQLitePersister
:members:

.. automethod:: __init__
7 changes: 7 additions & 0 deletions examples/ray/README.md
@@ -0,0 +1,7 @@
# Parallelism on Burr

This is supporting code for two blog posts:
1. [Parallel Multi Agent Workflows with Burr](https://blog.dagworks.io/p/93838d1f-52b5-4a72-999f-9cab9733d4fe)
2. [Parallel, Fault-Tolerant Agents with Burr/Ray](https://blog.dagworks.io/p/5baf1077-2490-44bc-afff-fcdafe18e819)

You can find the basic code in [application.py](application.py) and run it in [notebook.ipynb](notebook.ipynb). Read the blog posts to get a sense of the motivation and design behind this.
Empty file added: examples/ray/__init__.py
206 changes: 206 additions & 0 deletions examples/ray/application.py
@@ -0,0 +1,206 @@
from typing import Any, Dict, List, Optional, Tuple

import openai
import ray

from burr.common.async_utils import SyncOrAsyncGenerator
from burr.core import Application, ApplicationBuilder, Condition, GraphBuilder, State, action
from burr.core.application import ApplicationContext
from burr.core.parallelism import MapStates, RunnableGraph, SubgraphType
from burr.core.persistence import SQLitePersister
from burr.integrations.ray import RayExecutor


# full agent
def _query_llm(prompt: str) -> str:
"""Simple wrapper around the OpenAI API."""
client = openai.Client()
Review comment: Creating a new OpenAI client for each call can be inefficient. Consider initializing the client once and reusing it across function calls.
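# a possible refactor (sketch, not part of this diff): hoist the client to
# module scope, e.g. _client = openai.Client() next to the imports, so each
# call reuses one client instead of constructing a new one per request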

return (
client.chat.completions.create(
model="gpt-4o",
Review comment: model="gpt-4o" should be model="gpt-4".

messages=[
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": prompt},
],
)
.choices[0]
.message.content
)


@action(
reads=["feedback", "current_draft", "poem_type", "poem_subject"],
writes=["current_draft", "draft_history", "num_drafts"],
)
def write(state: State) -> Tuple[dict, State]:
"""Writes a draft of a poem."""
poem_subject = state["poem_subject"]
poem_type = state["poem_type"]
current_draft = state.get("current_draft")
feedback = state.get("feedback")

parts = [
f'You are an AI poet. Create a {poem_type} poem on the following subject: "{poem_subject}". '
"It is absolutely imperative that you respond with only the poem and no other text."
]

if current_draft:
parts.append(f'Here is the current draft of the poem: "{current_draft}".')

if feedback:
parts.append(f'Please incorporate the following feedback: "{feedback}".')

parts.append(
f"Ensure the poem is creative, adheres to the style of a {poem_type}, and improves upon the previous draft."
)

prompt = "\n".join(parts)

draft = _query_llm(prompt)

return {"draft": draft}, state.update(
current_draft=draft,
draft_history=state.get("draft_history", []) + [draft],
).increment(num_drafts=1)


@action(reads=["current_draft", "poem_type", "poem_subject"], writes=["feedback"])
def edit(state: State) -> Tuple[dict, State]:
"""Edits a draft of a poem, providing feedback"""
poem_subject = state["poem_subject"]
poem_type = state["poem_type"]
current_draft = state["current_draft"]

prompt = f"""
You are an AI poetry critic. Review the following {poem_type} poem based on the subject: "{poem_subject}".
Here is the current draft of the poem: "{current_draft}".
Provide detailed feedback to improve the poem. If the poem is already excellent and needs no changes, simply respond with an empty string.
"""
feedback = _query_llm(prompt)

return {"feedback": feedback}, state.update(feedback=feedback)


@action(reads=["current_draft"], writes=["final_draft"])
def final_draft(state: State) -> Tuple[dict, State]:
return {"final_draft": state["current_draft"]}, state.update(final_draft=state["current_draft"])


# full agent
@action(
reads=[],
writes=[
"max_drafts",
"poem_types",
"poem_subject",
],
)
def user_input(state: State, max_drafts: int, poem_types: List[str], poem_subject: str) -> State:
"""Collects user input for the poem generation process."""
return state.update(max_drafts=max_drafts, poem_types=poem_types, poem_subject=poem_subject)


class GenerateAllPoems(MapStates):
def states(
self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
) -> SyncOrAsyncGenerator[State]:
for poem_type in state["poem_types"]:
yield state.update(current_draft=None, poem_type=poem_type, feedback=[], num_drafts=0)

def action(self, state: State, inputs: Dict[str, Any]) -> SubgraphType:
graph = (
GraphBuilder()
.with_actions(
edit,
write,
final_draft,
)
.with_transitions(
("write", "edit", Condition.expr(f"num_drafts < {state['max_drafts']}")),
("write", "final_draft"),
("edit", "final_draft", Condition.expr("len(feedback) == 0")),
("edit", "write"),
)
).build()
return RunnableGraph(graph=graph, entrypoint="write", halt_after=["final_draft"])

def reduce(self, state: State, results: SyncOrAsyncGenerator[State]) -> State:
proposals = []
for output_state in results:
proposals.append(output_state["final_draft"])
return state.append(proposals=proposals)

@property
def writes(self) -> list[str]:
return ["proposals"]

@property
def reads(self) -> list[str]:
return ["poem_types", "poem_subject", "max_drafts"]


@action(reads=["proposals", "poem_types"], writes=["final_results"])
def final_results(state: State) -> Tuple[dict, State]:
# joins them into a string
proposals = state["proposals"]
final_results = "\n\n".join(
[f"{poem_type}:\n{proposal}" for poem_type, proposal in zip(state["poem_types"], proposals)]
)
return {"final_results": final_results}, state.update(final_results=final_results)


def application_multithreaded() -> Application:
app = (
ApplicationBuilder()
.with_actions(user_input, final_results, generate_all_poems=GenerateAllPoems())
.with_transitions(
("user_input", "generate_all_poems"),
("generate_all_poems", "final_results"),
)
.with_tracker(project="demo:parallel_agents")
.with_entrypoint("user_input")
.build()
)
return app


def application(app_id: Optional[str] = None) -> Application:
persister = SQLitePersister(db_path="./db")
persister.initialize()
app = (
ApplicationBuilder()
.with_actions(user_input, final_results, generate_all_poems=GenerateAllPoems())
.with_transitions(
("user_input", "generate_all_poems"),
("generate_all_poems", "final_results"),
)
.with_tracker(project="demo:parallel_agents_fault_tolerance")
.with_parallel_executor(RayExecutor)
.with_state_persister(persister)
.initialize_from(
persister, resume_at_next_action=True, default_state={}, default_entrypoint="user_input"
)
.with_identifiers(app_id=app_id)
.build()
)
return app


if __name__ == "__main__":
ray.init()
app = application()
app_id = app.uid
act, _, state = app.run(
halt_after=["final_results"],
inputs={
"max_drafts": 2,
"poem_types": [
"sonnet",
"limerick",
"haiku",
"acrostic",
],
"poem_subject": "state machines",
},
)
print(state)
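
Because state is persisted with SQLitePersister and the builder calls initialize_from(..., resume_at_next_action=True), an interrupted run can be resumed by rebuilding the application with the same app_id. A sketch (assuming this file is importable as application and the uid of the failed run was saved):

import ray

from application import application

ray.init()
previous_app_id = "..."  # fill in the app.uid captured from the interrupted run
app = application(app_id=previous_app_id)
act, _, state = app.run(
    halt_after=["final_results"],
    inputs={
        "max_drafts": 2,
        "poem_types": ["sonnet", "limerick", "haiku", "acrostic"],
        "poem_subject": "state machines",
    },
)
print(state)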