diff --git a/burr/core/__init__.py b/burr/core/__init__.py
index 52747b3c8..1653b14f8 100644
--- a/burr/core/__init__.py
+++ b/burr/core/__init__.py
@@ -3,13 +3,14 @@
 from burr.core.state import State
 
 __all__ = [
+    "action",
     "Action",
+    "Application",
     "ApplicationBuilder",
     "Condition",
-    "Result",
     "default",
-    "when",
     "expr",
-    "Application",
+    "Result",
     "State",
+    "when",
 ]
diff --git a/burr/core/action.py b/burr/core/action.py
index d06983151..83f54eef3 100644
--- a/burr/core/action.py
+++ b/burr/core/action.py
@@ -2,7 +2,8 @@
 import ast
 import copy
 import inspect
-from typing import Callable, List, Tuple, Union
+import types
+from typing import Any, Callable, List, Protocol, Tuple, TypeVar, Union
 
 from burr.core.state import State
 
@@ -169,7 +170,11 @@ class FunctionBasedAction(Action):
     ACTION_FUNCTION = "action_function"
 
     def __init__(
-        self, fn: Callable[[State], Tuple[dict, State]], reads: List[str], writes: List[str]
+        self,
+        fn: Callable[..., Tuple[dict, State]],
+        reads: List[str],
+        writes: List[str],
+        bound_params: dict = None,
     ):
         """Instantiates a function-based action with the given function, reads, and writes.
         The function must take in a state and return a tuple of (result, new_state).
@@ -183,13 +188,18 @@ def __init__(
         self._reads = reads
         self._writes = writes
         self._state_created = None
+        self._bound_params = bound_params if bound_params is not None else {}
+
+    @property
+    def fn(self) -> Callable:
+        return self._fn
 
     @property
     def reads(self) -> list[str]:
         return self._reads
 
     def run(self, state: State) -> dict:
-        result, new_state = self._fn(state)
+        result, new_state = self._fn(state, **self._bound_params)
         self._state_created = new_state
         return result
 
@@ -202,7 +212,23 @@ def update(self, result: dict, state: State) -> State:
             raise ValueError(
                 "FunctionBasedAction.run must be called before FunctionBasedAction.update"
             )
-        return self._state_created
+        # TODO -- validate that all the keys are contained -- fix up subset to handle this
+        # TODO -- validate that we've (a) written only to the write ones (by diffing the read ones),
+        # and (b) written to no more than the write ones
+        return self._state_created.subset(*self._writes)
+
+    def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
+        """Binds parameters to the function.
+        Note that there is no reason for the user to call this directly. This *could*
+        be done at the class level, but given that the API allows for constructor parameters
+        (which do the same thing in a cleaner way), it is best to keep it here for now.
+
+        :param kwargs:
+        :return:
+        """
+        new_action = copy.copy(self)
+        new_action._bound_params = {**self._bound_params, **kwargs}
+        return new_action
 
 
 def _validate_action_function(fn: Callable):
@@ -225,7 +251,23 @@ def _validate_action_function(fn: Callable):
         )
 
 
-def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Callable]:
+C = TypeVar("C", bound=Callable)  # placeholder for any Callable
+
+
+class FunctionRepresentingAction(Protocol[C]):
+    action_function: FunctionBasedAction
+    __call__: C
+
+    def bind(self, **kwargs: Any):
+        ...
+
+
+def bind(self: FunctionRepresentingAction, **kwargs: Any) -> FunctionRepresentingAction:
+    self.action_function = self.action_function.with_params(**kwargs)
+    return self
+
+
+def action(reads: List[str], writes: List[str]) -> Callable[[Callable], FunctionRepresentingAction]:
     """Decorator to create a function-based action. This is user-facing.
     Note that, in the future, with typed state, we may not need this for all cases.
@@ -235,8 +277,9 @@ def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Callable
     :return: The decorator to assign the function as an action
     """
 
-    def decorator(fn: Callable) -> Callable:
+    def decorator(fn) -> FunctionRepresentingAction:
         setattr(fn, FunctionBasedAction.ACTION_FUNCTION, FunctionBasedAction(fn, reads, writes))
+        setattr(fn, "bind", types.MethodType(bind, fn))
         return fn
 
     return decorator
diff --git a/burr/core/application.py b/burr/core/application.py
index cbbc724fc..b93189c54 100644
--- a/burr/core/application.py
+++ b/burr/core/application.py
@@ -1,7 +1,18 @@
 import collections
 import dataclasses
 import logging
-from typing import Any, AsyncGenerator, Generator, List, Literal, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    AsyncGenerator,
+    Callable,
+    Generator,
+    List,
+    Literal,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 from burr.core.action import Action, Condition, Function, Reducer, create_action, default
 from burr.core.state import State
@@ -66,7 +77,7 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> State:
     :return:
     """
     state_to_use = state.subset(*reducer.writes)
-    new_state = reducer.update(result, state_to_use)
+    new_state = reducer.update(result, state_to_use).subset(*reducer.writes)
     keys_in_new_state = set(new_state.keys())
     extra_keys = keys_in_new_state - set(reducer.writes)
     if extra_keys:
@@ -440,7 +451,7 @@ def with_entrypoint(self, action: str) -> "ApplicationBuilder":
         self.start = action
         return self
 
-    def with_actions(self, **actions: Action) -> "ApplicationBuilder":
+    def with_actions(self, **actions: Union[Action, Callable]) -> "ApplicationBuilder":
         """Adds an action to the application. The actions are granted names (using the with_name)
         method post-adding, using the kw argument. Thus, this is the only supported way to add actions.
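
A minimal sketch of how the decorator and bind() above compose (the `greet` action and its parameter are hypothetical, not part of the patch):

from typing import Tuple

from burr.core import State
from burr.core.action import FunctionBasedAction, action


@action(reads=["name"], writes=["greeting"])
def greet(state: State, salutation: str = "hello") -> Tuple[dict, State]:
    # A decorated function takes state plus any bound parameters and returns
    # (result, new_state), matching the Callable[..., Tuple[dict, State]] signature.
    result = {"greeting": f"{salutation}, {state['name']}"}
    return result, state.update(**result)


# The decorator attaches a FunctionBasedAction under the ACTION_FUNCTION attribute...
assert isinstance(getattr(greet, FunctionBasedAction.ACTION_FUNCTION), FunctionBasedAction)
# ...and bind() delegates to with_params(), which copies that action with the keyword
# arguments merged into its bound parameters; run() then calls fn(state, **bound_params).
greet.bind(salutation="howdy")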
diff --git a/burr/integrations/streamlit.py b/burr/integrations/streamlit.py
index 4e4540aba..5b561f7ee 100644
--- a/burr/integrations/streamlit.py
+++ b/burr/integrations/streamlit.py
@@ -4,6 +4,7 @@
 from typing import List, Optional
 
 from burr.core import Application
+from burr.core.action import FunctionBasedAction
 from burr.integrations.base import require_plugin
 from burr.integrations.hamilton import Hamilton, StateSource
 
@@ -176,6 +177,7 @@ def render_action(state: AppState):
     st.header(f"`{current_node}`")
     action_object = actions[current_node]
     is_hamilton = isinstance(action_object, Hamilton)
+    is_function_api = isinstance(action_object, FunctionBasedAction)
 
     def format_read(var):
         out = f"- `{var}`"
@@ -210,6 +212,9 @@ def format_write(var):
     if is_hamilton:
         digraph = action_object.visualize_step(show_legend=False)
         st.graphviz_chart(digraph, use_container_width=False)
+    elif is_function_api:
+        code = inspect.getsource(action_object.fn)
+        st.code(code, language="python")
     else:
         code = inspect.getsource(action_object.__class__)
         st.code(code, language="python")
diff --git a/examples/counter/application.py b/examples/counter/application.py
index 9fe3db26a..f2ff8872f 100644
--- a/examples/counter/application.py
+++ b/examples/counter/application.py
@@ -1,31 +1,22 @@
+from typing import Tuple
+
 import burr.core
-from burr.core import Action, Result, State, default, expr
+from burr.core import Result, State, default, expr
+from burr.core.action import action
 from burr.lifecycle import StateAndResultsFullLogger
 
 
-class CounterAction(Action):
-    @property
-    def reads(self) -> list[str]:
-        return ["counter"]
-
-    def run(self, state: State) -> dict:
-        return {"counter": state["counter"] + 1}
-
-    @property
-    def writes(self) -> list[str]:
-        return ["counter"]
-
-    def update(self, result: dict, state: State) -> State:
-        return state.update(**result)
+@action(reads=["counter"], writes=["counter"])
+def counter(state: State) -> Tuple[dict, State]:
+    result = {"counter": state["counter"] + 1}
+    return result, state.update(**result)
 
 
 def application(count_up_to: int = 10, log_file: str = None):
     return (
         burr.core.ApplicationBuilder()
-        .with_state(
-            counter=0,
-        )
-        .with_actions(counter=CounterAction(), result=Result(["counter"]))
+        .with_state(counter=0)
+        .with_actions(counter=counter, result=Result(["counter"]))
         .with_transitions(
             ("counter", "counter", expr(f"counter < {count_up_to}")),
             ("counter", "result", default),
diff --git a/examples/cowsay/application.py b/examples/cowsay/application.py
index 7f7f014d5..cd85f65c5 100644
--- a/examples/cowsay/application.py
+++ b/examples/cowsay/application.py
@@ -1,54 +1,14 @@
 import random
 import time
-from typing import List, Optional
+from typing import Tuple
 
 import cowsay
 
 from burr.core import Action, Application, ApplicationBuilder, State, default, expr
+from burr.core.action import action
 from burr.lifecycle import PostRunStepHook
 
 
-class CowSay(Action):
-    def __init__(self, say_what: List[Optional[str]]):
-        super(CowSay, self).__init__()
-        self.say_what = say_what
-
-    @property
-    def reads(self) -> list[str]:
-        return []
-
-    def run(self, state: State) -> dict:
-        say_what = random.choice(self.say_what)
-        return {
-            "cow_said": cowsay.get_output_string("cow", say_what) if say_what is not None else None
-        }
-
-    @property
-    def writes(self) -> list[str]:
-        return ["cow_said"]
-
-    def update(self, result: dict, state: State) -> State:
-        return state.update(**result)
-
-
-class CowShouldSay(Action):
-    @property
-    def reads(self) -> list[str]:
-        return []
-
-    def run(self, state: State) -> dict:
-        if not random.randint(0, 3):
-            return {"cow_should_speak": True}
-        return {"cow_should_speak": False}
-
-    @property
-    def writes(self) -> list[str]:
-        return ["cow_should_speak"]
-
-    def update(self, result: dict, state: State) -> State:
-        return state.update(**result)
-
-
 class PrintWhatTheCowSaid(PostRunStepHook):
     def post_run_step(self, *, state: "State", action: "Action", **future_kwargs):
         if action.name != "cow_should_say" and state["cow_said"] is not None:
@@ -65,6 +25,19 @@ def post_run_step(self, *, state: "State", action: "Action", **future_kwargs):
         time.sleep(self.sleep_time)
 
 
+@action(reads=[], writes=["cow_said"])
+def cow_said(state: State, say_what: list[str]) -> Tuple[dict, State]:
+    said = random.choice(say_what)
+    result = {"cow_said": cowsay.get_output_string("cow", said) if said is not None else None}
+    return result, state.update(**result)
+
+
+@action(reads=[], writes=["cow_should_speak"])
+def cow_should_speak(state: State) -> Tuple[dict, State]:
+    result = {"cow_should_speak": random.randint(0, 3) == 0}
+    return result, state.update(**result)
+
+
 def application(in_terminal: bool = False) -> Application:
     hooks = (
         [
@@ -76,21 +49,21 @@ def application(in_terminal: bool = False) -> Application:
     )
     return (
         ApplicationBuilder()
-        .with_state(
-            cow_said=None,
-        )
+        .with_state(cow_said=None)
         .with_actions(
-            say_nothing=CowSay([None]),
-            say_hello=CowSay(["Hello world!", "What's up?", "Are you Aaron Burr, sir?"]),
-            cow_should_say=CowShouldSay(),
+            say_nothing=cow_said.bind(say_what=[None]),
+            say_hello=cow_said.bind(
+                say_what=["Hello world!", "What's up?", "Are you Aaron Burr, sir?"]
+            ),
+            cow_should_speak=cow_should_speak,
         )
         .with_transitions(
-            ("cow_should_say", "say_hello", expr("cow_should_speak")),
-            ("say_hello", "cow_should_say", default),
-            ("cow_should_say", "say_nothing", expr("not cow_should_speak")),
-            ("say_nothing", "cow_should_say", default),
+            ("cow_should_speak", "say_hello", expr("cow_should_speak")),
+            ("say_hello", "cow_should_speak", default),
+            ("cow_should_speak", "say_nothing", expr("not cow_should_speak")),
+            ("say_nothing", "cow_should_speak", default),
         )
-        .with_entrypoint("cow_should_say")
+        .with_entrypoint("cow_should_speak")
         .with_hooks(*hooks)
         .build()
     )
@@ -100,4 +73,4 @@
     app = application(in_terminal=True)
     app.visualize(output_file_path="cowsay.png", include_conditions=True, view=True)
     while True:
-        state, result, action = app.step()
+        s, r, action = app.step()
diff --git a/examples/gpt/application.py b/examples/gpt/application.py
index 97aa53039..b0c0a3246 100644
--- a/examples/gpt/application.py
+++ b/examples/gpt/application.py
@@ -1,58 +1,35 @@
-import abc
 import functools
-from typing import List
+from typing import List, Tuple
 
 import dag
 import openai
 
-from burr.core import Action, Application, ApplicationBuilder, State, default, expr, when
+from burr.core import Application, ApplicationBuilder, State, default, expr, when
+from burr.core.action import action
 from burr.integrations.hamilton import Hamilton, append_state, from_state, update_state
 from burr.lifecycle import LifecycleAdapter
 from hamilton import driver
 
+MODES = {
+    "answer_question": "text",
+    "generate_image": "image",
+    "generate_code": "code",
+    "unknown": "text",
+}
 
-class PromptInput(Action):
-    @property
-    def reads(self) -> list[str]:
-        return ["prompt"]
-
-    def run(self, state: State) -> dict:
-        return {"processed_prompt": {"role": "user", "content": state["prompt"], "type": "text"}}
-
-    @property
-    def writes(self) -> list[str]:
-        return ["chat_history"]
-
-    def update(self, result: dict, state: State) -> State:
-        return state.wipe(keep=["prompt", "chat_history"]).append(
-            chat_history=result["processed_prompt"]
-        )
-
-
-class SafetyCheck(Action):
-    @property
-    def reads(self) -> list[str]:
-        return ["prompt"]
-
-    def run(self, state: State) -> dict:
-        if "unsafe" in state["prompt"]:
-            # quick for testing
-            return {"safe": False}
-        return {"safe": True}
-
-    @property
-    def writes(self) -> list[str]:
-        return ["safe"]
-
-    def update(self, result: dict, state: State) -> State:
-        return state.update(safe=result["safe"])
+@action(reads=["prompt"], writes=["chat_history"])
+def process_prompt(state: State) -> Tuple[dict, State]:
+    result = {"processed_prompt": {"role": "user", "content": state["prompt"], "type": "text"}}
+    return result, state.wipe(keep=["prompt", "chat_history"]).append(
+        chat_history=result["processed_prompt"]
+    )
 
-MODES = [
-    "answer_question",
-    "draw_image",
-    "generate_code",
-]
+@action(reads=["prompt"], writes=["safe"])
+def check_safety(state: State) -> Tuple[dict, State]:
+    result = {"safe": "unsafe" not in state["prompt"]}  # quick hack to demonstrate
+    return result, state.update(safe=result["safe"])
 
 
 @functools.lru_cache(maxsize=None)
@@ -60,84 +37,51 @@ def _get_openai_client():
     return openai.Client()
 
 
-class ChooseMode(Action):
-    def __init__(
-        self, client: openai.Client = _get_openai_client(), model: str = "gpt-4", modes=tuple(MODES)
-    ):
-        super(ChooseMode, self).__init__()
-        self.client = client
-        self.model = model
-        self.modes = modes
-
-    @property
-    def reads(self) -> list[str]:
-        return ["prompt"]
-
-    def run(self, state: State) -> dict:
-        prompt = (
-            f"You are a chatbot. You've been prompted this: {state['prompt']}. "
-            f"You have the capability of responding in the following modes: {', '.join(self.modes)}. "
-            "Please respond with *only* a single word representing the mode that most accurately"
-            " corresponds to the prompt. Fr instance, if the prompt is 'draw a picture of a cat', "
-            "the mode would be 'image'. If the prompt is 'what is the capital of France', the mode would be 'text'."
-            "If none of these modes apply, please respond with 'unknown'."
-        )
-        result = self.client.chat.completions.create(
-            model=self.model,
-            messages=[
-                {"role": "system", "content": "You are a helpful assistant"},
-                {"role": "user", "content": prompt},
-            ],
-        )
-        content = result.choices[0].message.content
-        mode = content.lower()
-        if mode not in self.modes:
-            mode = "unknown"
-        return {"mode": mode}
-
-    @property
-    def writes(self) -> list[str]:
-        return ["mode"]
-
-    def update(self, result: dict, state: State) -> State:
-        return state.update(mode=result["mode"])
-
-
-class BaseChatCompletion(Action, abc.ABC):
-    @property
-    def reads(self) -> list[str]:
-        return ["prompt", "chat_history"]
-
-    @abc.abstractmethod
-    def chat_response(self, state: State) -> dict:
-        pass
-
-    def run(self, state: State) -> dict:
-        return {"response": self.chat_response(state)}
-
-    @property
-    def writes(self) -> list[str]:
-        return ["response"]
-
-    def update(self, result: dict, state: State) -> State:
-        return state.update(**result)
-
-
-class DontKnowResponse(BaseChatCompletion):
-    def __init__(self, modes=tuple(MODES)):
-        super(DontKnowResponse, self).__init__()
-        self.modes = modes
+@action(reads=["prompt"], writes=["mode"])
+def choose_mode(state: State) -> Tuple[dict, State]:
+    prompt = (
+        f"You are a chatbot. You've been prompted this: {state['prompt']}. "
+        f"You have the capability of responding in the following modes: {', '.join(MODES)}. "
+        "Please respond with *only* a single word representing the mode that most accurately"
+        " corresponds to the prompt. For instance, if the prompt is 'draw a picture of a cat', "
+        "the mode would be 'generate_image'. If the prompt is 'what is the capital of France', the mode would be 'answer_question'. "
+        "If none of these modes apply, please respond with 'unknown'."
+    )
 
-    def chat_response(self, state: State) -> dict:
-        return {
-            "content": f"None of the response modes I support: ({','.join(self.modes)}) "
+    result = _get_openai_client().chat.completions.create(
+        model="gpt-4",
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": prompt},
+        ],
+    )
+    content = result.choices[0].message.content
+    mode = content.lower()
+    if mode not in MODES:
+        mode = "unknown"
+    result = {"mode": mode}
+    return result, state.update(**result)
+
+
+@action(reads=["prompt", "chat_history"], writes=["response"])
+def prompt_for_more(state: State) -> Tuple[dict, State]:
+    result = {
+        "response": {
+            "content": f"None of the response modes I support: ({','.join(MODES)}) "
             f"apply to your question. Please clarify?",
             "type": "text",
             "role": "assistant",
         }
+    }
+    return result, state.update(**result)
 
-def _get_text_response(chat_history: list[dict], model: str, client: openai.Client) -> str:
+@action(reads=["prompt", "chat_history", "mode"], writes=["response"])
+def chat_response(
+    state: State, prepend_prompt: str, display_type: str = "text", model: str = "gpt-3.5-turbo"
+) -> Tuple[dict, State]:
+    chat_history = state["chat_history"].copy()
+    chat_history[-1]["content"] = f"{prepend_prompt}: {chat_history[-1]['content']}"
     chat_history_api_format = [
         {
             "role": chat["role"],
@@ -145,115 +89,67 @@ def _get_text_response(chat_history: list[dict], model: str, client: openai.Client) -> str:
         }
         for chat in chat_history
     ]
+    client = _get_openai_client()
     result = client.chat.completions.create(
         model=model,
         messages=chat_history_api_format,
     )
-    return result.choices[0].message.content
-
-
-class AnswerQuestionResponse(BaseChatCompletion):
-    def __init__(self, client: openai.Client = _get_openai_client(), model: str = "gpt-4"):
-        super(AnswerQuestionResponse, self).__init__()
-        self.client = client
-        self.model = model
-
-    def chat_response(self, state: State) -> dict:
-        chat_history = state["chat_history"].copy()
-        chat_history[-1][
-            "content"
-        ] = f"Please answer the following question: {chat_history[-1]['content']}"
-        response = _get_text_response(chat_history, self.model, self.client)
-        return {"content": response, "type": "text", "role": "assistant"}
-
-
-class GenerateImageResponse(BaseChatCompletion):
-    def __init__(self, client: openai.Client = _get_openai_client(), model: str = "dall-e-2"):
-        super(GenerateImageResponse, self).__init__()
-        self.client = client
-        self.model = model
+    response = result.choices[0].message.content
+    result = {"response": {"content": response, "type": MODES[state["mode"]], "role": "assistant"}}
+    return result, state.update(**result)
 
-    def chat_response(self, state: State) -> dict:
-        result = self.client.images.generate(
-            model=self.model, prompt=state["prompt"], size="1024x1024", quality="standard", n=1
-        )
-        return {"content": result.data[0].url, "type": "image", "role": "assistant"}
-
-
-class GenerateCodeResponse(BaseChatCompletion):
-    def __init__(self, client: openai.Client = _get_openai_client(), model: str = "gpt-4"):
-        super(GenerateCodeResponse, self).__init__()
-        self.client = client
-        self.model = model
 
-    def chat_response(self, state: State) -> dict:
-        chat_history = state["chat_history"].copy()
-        chat_history[-1]["content"] = (
-            f"Please answer the following question, "
-            f"responding *only* with code, and nothing else: {chat_history[-1]['content']}"
-        )
-        return {
-            "content": _get_text_response(state["chat_history"], self.model, self.client),
-            "type": "code",
-            "role": "assistant",
-        }
-
-
-class Response(Action):
-    @property
-    def reads(self) -> list[str]:
-        return ["response", "safe", "mode"]
-
-    def run(self, state: State) -> dict:
-        if not state["safe"]:
-            return {
-                "processed_response": {
-                    "role": "assistant",
-                    "content": "I'm sorry, I can't respond to that.",
-                    "type": "text",
-                }
+@action(reads=["prompt", "chat_history", "mode"], writes=["response"])
+def image_response(state: State, model: str = "dall-e-2") -> Tuple[dict, State]:
+    client = _get_openai_client()
+    result = client.images.generate(
+        model=model, prompt=state["prompt"], size="1024x1024", quality="standard", n=1
+    )
+    response = result.data[0].url
+    result = {"response": {"content": response, "type": MODES[state["mode"]], "role": "assistant"}}
+    return result, state.update(**result)
+
+
+@action(reads=["response", "safe", "mode"], writes=["chat_history"])
+def response(state: State) -> Tuple[dict, State]:
+    if not state["safe"]:
+        result = {
+            "processed_response": {
+                "role": "assistant",
+                "content": "I'm sorry, I can't respond to that.",
+                "type": "text",
             }
-        return {"processed_response": state["response"]}
-
-    @property
-    def writes(self) -> list[str]:
-        return ["chat_history"]
-
-    def update(self, result: dict, state: State) -> State:
-        return state.append(chat_history=result["processed_response"])
-
-
-class Error(Action):
-    @property
-    def reads(self) -> list[str]:
-        return ["error"]
-
-    def run(self, state: State) -> dict:
-        return {
-            "chat_record": {"role": "assistant", "content": str(state["error"]), "type": "error"}
         }
+    else:
+        result = {"processed_response": state["response"]}
+    return result, state.append(chat_history=result["processed_response"])
 
-    @property
-    def writes(self) -> list[str]:
-        return ["chat_history"]
 
-    def update(self, result: dict, state: State) -> State:
-        return state.append(chat_history=result["chat_record"])
+@action(reads=["error"], writes=["chat_history"])
+def error(state: State) -> Tuple[dict, State]:
+    result = {"chat_record": {"role": "assistant", "content": str(state["error"]), "type": "error"}}
+    return result, state.append(chat_history=result["chat_record"])
 
 
-def base_application(hooks: List[LifecycleAdapter] = []):
+def base_application(hooks: List[LifecycleAdapter] = None):
+    if hooks is None:
+        hooks = []
     return (
         ApplicationBuilder()
         .with_actions(
-            prompt=PromptInput(),
-            check_safety=SafetyCheck(),
-            decide_mode=ChooseMode(),
-            generate_image=GenerateImageResponse(),
-            generate_code=GenerateCodeResponse(),
-            answer_question=AnswerQuestionResponse(),
-            prompt_for_more=DontKnowResponse(),
-            response=Response(),
-            error=Error(),
+            prompt=process_prompt,
+            check_safety=check_safety,
+            decide_mode=choose_mode,
+            generate_image=image_response,
+            generate_code=chat_response.bind(
+                prepend_prompt="Please respond with *only* code and no other text (at all) to the following:",
+            ),
+            answer_question=chat_response.bind(
+                prepend_prompt="Please answer the following question:",
+            ),
+            prompt_for_more=prompt_for_more,
+            response=response,
+            error=error,
         )
.with_entrypoint("prompt") .with_state(chat_history=[]) @@ -283,7 +179,9 @@ def base_application(hooks: List[LifecycleAdapter] = []): ) -def hamilton_application(hooks: List[LifecycleAdapter] = []): +def hamilton_application(hooks: List[LifecycleAdapter] = None): + if hooks is None: + hooks = [] dr = driver.Driver({"provider": "openai"}, dag) # TODO -- add modules Hamilton.set_driver(dr) application = ( @@ -328,7 +226,7 @@ def hamilton_application(hooks: List[LifecycleAdapter] = []): }, outputs={"processed_response": append_state("chat_history")}, ), - error=Error(), + error=error, ) .with_transitions( ("prompt", "check_safety", default), @@ -364,7 +262,7 @@ def application(use_hamilton: bool, hooks: List[LifecycleAdapter] = []) -> Appli if __name__ == "__main__": - app = application(use_hamilton=True) + app = application(use_hamilton=False) # state, result = app.run(until=["result"]) app.visualize(output_file_path="gpt", include_conditions=False, view=True, format="png") # assert state["counter"] == 10 diff --git a/examples/gpt/streamlit_app.py b/examples/gpt/streamlit_app.py index 9e1d662d2..f9e4ee7b5 100644 --- a/examples/gpt/streamlit_app.py +++ b/examples/gpt/streamlit_app.py @@ -11,7 +11,6 @@ set_slider_to_current, update_state, ) -from burr.lifecycle.default import SlowDownHook def render_chat_message(record: Record): @@ -34,7 +33,7 @@ def render_chat_message(record: Record): def retrieve_state(): if "burr_state" not in st.session_state: state = AppState.from_empty( - app=chatbot_application.application(use_hamilton=True, hooks=[SlowDownHook(0.0, 0)]), + app=chatbot_application.application(use_hamilton=False), ) else: state = get_state()