Skip to content

Commit

Permalink
WIP for function-based actions
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Feb 7, 2024
1 parent 2b814d8 commit 71ea2d8
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 297 deletions.
7 changes: 4 additions & 3 deletions burr/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from burr.core.state import State

__all__ = [
"action",
"Action",
"Application",
"ApplicationBuilder",
"Condition",
"Result",
"default",
"when",
"expr",
"Application",
"Result",
"State",
"when",
]
55 changes: 49 additions & 6 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import ast
import copy
import inspect
from typing import Callable, List, Tuple, Union
import types
from typing import Any, Callable, List, Protocol, Tuple, TypeVar, Union

from burr.core.state import State

Expand Down Expand Up @@ -169,7 +170,11 @@ class FunctionBasedAction(Action):
ACTION_FUNCTION = "action_function"

def __init__(
self, fn: Callable[[State], Tuple[dict, State]], reads: List[str], writes: List[str]
self,
fn: Callable[..., Tuple[dict, State]],
reads: List[str],
writes: List[str],
bound_params: dict = None,
):
"""Instantiates a function-based action with the given function, reads, and writes.
The function must take in a state and return a tuple of (result, new_state).
Expand All @@ -183,13 +188,18 @@ def __init__(
self._reads = reads
self._writes = writes
self._state_created = None
self._bound_params = bound_params if bound_params is not None else {}

@property
def fn(self) -> Callable:
return self._fn

@property
def reads(self) -> list[str]:
return self._reads

def run(self, state: State) -> dict:
result, new_state = self._fn(state)
result, new_state = self._fn(state, **self._bound_params)
self._state_created = new_state
return result

Expand All @@ -202,7 +212,23 @@ def update(self, result: dict, state: State) -> State:
raise ValueError(
"FunctionBasedAction.run must be called before FunctionBasedAction.update"
)
return self._state_created
# TODO -- validate that all the keys are contained -- fix up subset to handle this
# TODO -- validate that we've (a) written only to the write ones (by diffing the read ones),
# and (b) written to no more than the write ones
return self._state_created.subset(*self._writes)

def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
"""Binds parameters to the function.
Note that there is no reason to call this by the user. This *could*
be done at the class level, but given that API allows for constructor parameters
(which do the same thing in a cleaner way), it is best to keep it here for now.
:param kwargs:
:return:
"""
new_action = copy.copy(self)
new_action._bound_params = {**self._bound_params, **kwargs}
return new_action


def _validate_action_function(fn: Callable):
Expand All @@ -225,7 +251,23 @@ def _validate_action_function(fn: Callable):
)


def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Callable]:
C = TypeVar("C", bound=Callable) # placeholder for any Callable


class FunctionRepresentingAction(Protocol[C]):
action_function: FunctionBasedAction
__call__: C

def bind(self, **kwargs: Any):
...


def bind(self: FunctionRepresentingAction, **kwargs: Any) -> FunctionRepresentingAction:
self.action_function = self.action_function.with_params(**kwargs)
return self


def action(reads: List[str], writes: List[str]) -> Callable[[Callable], FunctionRepresentingAction]:
"""Decorator to create a function-based action. This is user-facing.
Note that, in the future, with typed state, we may not need this for
all cases.
Expand All @@ -235,8 +277,9 @@ def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Callable
:return: The decorator to assign the function as an action
"""

def decorator(fn: Callable) -> Callable:
def decorator(fn) -> FunctionRepresentingAction:
setattr(fn, FunctionBasedAction.ACTION_FUNCTION, FunctionBasedAction(fn, reads, writes))
setattr(fn, "bind", types.MethodType(bind, fn))
return fn

return decorator
Expand Down
17 changes: 14 additions & 3 deletions burr/core/application.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import collections
import dataclasses
import logging
from typing import Any, AsyncGenerator, Generator, List, Literal, Optional, Set, Tuple, Union
from typing import (
Any,
AsyncGenerator,
Callable,
Generator,
List,
Literal,
Optional,
Set,
Tuple,
Union,
)

from burr.core.action import Action, Condition, Function, Reducer, create_action, default
from burr.core.state import State
Expand Down Expand Up @@ -66,7 +77,7 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta
:return:
"""
state_to_use = state.subset(*reducer.writes)
new_state = reducer.update(result, state_to_use)
new_state = reducer.update(result, state_to_use).subset(*reducer.writes)
keys_in_new_state = set(new_state.keys())
extra_keys = keys_in_new_state - set(reducer.writes)
if extra_keys:
Expand Down Expand Up @@ -440,7 +451,7 @@ def with_entrypoint(self, action: str) -> "ApplicationBuilder":
self.start = action
return self

def with_actions(self, **actions: Action) -> "ApplicationBuilder":
def with_actions(self, **actions: Union[Action, Callable]) -> "ApplicationBuilder":
"""Adds an action to the application. The actions are granted names (using the with_name)
method post-adding, using the kw argument. Thus, this is the only supported way to add actions.
Expand Down
5 changes: 5 additions & 0 deletions burr/integrations/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Optional

from burr.core import Application
from burr.core.action import FunctionBasedAction
from burr.integrations.base import require_plugin
from burr.integrations.hamilton import Hamilton, StateSource

Expand Down Expand Up @@ -176,6 +177,7 @@ def render_action(state: AppState):
st.header(f"`{current_node}`")
action_object = actions[current_node]
is_hamilton = isinstance(action_object, Hamilton)
is_function_api = isinstance(action_object, FunctionBasedAction)

def format_read(var):
out = f"- `{var}`"
Expand Down Expand Up @@ -210,6 +212,9 @@ def format_write(var):
if is_hamilton:
digraph = action_object.visualize_step(show_legend=False)
st.graphviz_chart(digraph, use_container_width=False)
elif is_function_api:
code = inspect.getsource(action_object.fn)
st.code(code, language="python")
else:
code = inspect.getsource(action_object.__class__)
st.code(code, language="python")
Expand Down
29 changes: 10 additions & 19 deletions examples/counter/application.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,22 @@
from typing import Tuple

import burr.core
from burr.core import Action, Result, State, default, expr
from burr.core import Result, State, default, expr
from burr.core.action import action
from burr.lifecycle import StateAndResultsFullLogger


class CounterAction(Action):
@property
def reads(self) -> list[str]:
return ["counter"]

def run(self, state: State) -> dict:
return {"counter": state["counter"] + 1}

@property
def writes(self) -> list[str]:
return ["counter"]

def update(self, result: dict, state: State) -> State:
return state.update(**result)
@action(reads=["counter"], writes=["counter"])
def counter(state: State) -> Tuple[dict, State]:
result = {"counter": state["counter"] + 1}
return result, state.update(**result)


def application(count_up_to: int = 10, log_file: str = None):
return (
burr.core.ApplicationBuilder()
.with_state(
counter=0,
)
.with_actions(counter=CounterAction(), result=Result(["counter"]))
.with_state(counter=0)
.with_actions(counter=counter, result=Result(["counter"]))
.with_transitions(
("counter", "counter", expr(f"counter < {count_up_to}")),
("counter", "result", default),
Expand Down
81 changes: 27 additions & 54 deletions examples/cowsay/application.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,14 @@
import random
import time
from typing import List, Optional
from typing import Tuple

import cowsay

from burr.core import Action, Application, ApplicationBuilder, State, default, expr
from burr.core.action import action
from burr.lifecycle import PostRunStepHook


class CowSay(Action):
def __init__(self, say_what: List[Optional[str]]):
super(CowSay, self).__init__()
self.say_what = say_what

@property
def reads(self) -> list[str]:
return []

def run(self, state: State) -> dict:
say_what = random.choice(self.say_what)
return {
"cow_said": cowsay.get_output_string("cow", say_what) if say_what is not None else None
}

@property
def writes(self) -> list[str]:
return ["cow_said"]

def update(self, result: dict, state: State) -> State:
return state.update(**result)


class CowShouldSay(Action):
@property
def reads(self) -> list[str]:
return []

def run(self, state: State) -> dict:
if not random.randint(0, 3):
return {"cow_should_speak": True}
return {"cow_should_speak": False}

@property
def writes(self) -> list[str]:
return ["cow_should_speak"]

def update(self, result: dict, state: State) -> State:
return state.update(**result)


class PrintWhatTheCowSaid(PostRunStepHook):
def post_run_step(self, *, state: "State", action: "Action", **future_kwargs):
if action.name != "cow_should_say" and state["cow_said"] is not None:
Expand All @@ -65,6 +25,19 @@ def post_run_step(self, *, state: "State", action: "Action", **future_kwargs):
time.sleep(self.sleep_time)


@action(reads=[], writes=["cow_said"])
def cow_said(state: State, say_what: list[str]) -> Tuple[dict, State]:
said = random.choice(say_what)
result = {"cow_said": cowsay.get_output_string("cow", said) if say_what is not None else None}
return result, state.update(**result)


@action(reads=[], writes=["cow_should_speak"])
def cow_should_speak(state: State) -> Tuple[dict, State]:
result = {"cow_should_speak": random.randint(0, 3) == 0}
return result, state.update(**result)


def application(in_terminal: bool = False) -> Application:
hooks = (
[
Expand All @@ -76,21 +49,21 @@ def application(in_terminal: bool = False) -> Application:
)
return (
ApplicationBuilder()
.with_state(
cow_said=None,
)
.with_state(cow_said=None)
.with_actions(
say_nothing=CowSay([None]),
say_hello=CowSay(["Hello world!", "What's up?", "Are you Aaron Burr, sir?"]),
cow_should_say=CowShouldSay(),
say_nothing=cow_said.bind(say_what=None),
say_hello=cow_said.bind(
say_what=["Hello world!", "What's up?", "Are you Aaron Burr, sir?"]
),
cow_should_speak=cow_should_speak,
)
.with_transitions(
("cow_should_say", "say_hello", expr("cow_should_speak")),
("say_hello", "cow_should_say", default),
("cow_should_say", "say_nothing", expr("not cow_should_speak")),
("say_nothing", "cow_should_say", default),
("cow_should_speak", "say_hello", expr("cow_should_speak")),
("say_hello", "cow_should_speak", default),
("cow_should_speak", "say_nothing", expr("not cow_should_speak")),
("say_nothing", "cow_should_speak", default),
)
.with_entrypoint("cow_should_say")
.with_entrypoint("cow_should_speak")
.with_hooks(*hooks)
.build()
)
Expand All @@ -100,4 +73,4 @@ def application(in_terminal: bool = False) -> Application:
app = application(in_terminal=True)
app.visualize(output_file_path="cowsay.png", include_conditions=True, view=True)
while True:
state, result, action = app.step()
s, r, action = app.step()
Loading

0 comments on commit 71ea2d8

Please sign in to comment.