From ddd675a89d5b0c82c9c25c888a27d46dba913864 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Wed, 4 Dec 2024 09:13:18 -0800 Subject: [PATCH] Additional updates to parallelism docs More updates to make API consistent: 1. Showing inputs for class-based action (repeat but good to hammer home) 2. Making API consistent --- docs/concepts/actions.rst | 5 ++++ docs/concepts/parallelism.rst | 43 ++++++++++++++++++----------------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/docs/concepts/actions.rst b/docs/concepts/actions.rst index d871f5ef..d5aaff0d 100644 --- a/docs/concepts/actions.rst +++ b/docs/concepts/actions.rst @@ -30,6 +30,8 @@ There are two APIs for defining actions: class-based and function-based. They ar - use the function-based API when you want to write something quick and terse that reads from a fixed set of state variables - use the class-based API when you want to leverage inheritance or parameterize the action in more powerful ways +.. _functionbasedactions: + ---------------------- Function-based actions ---------------------- @@ -127,6 +129,9 @@ injected into your Burr Actions. This is done by adding ``__context`` to the act Class-Based Actions ------------------- +.. _classbasedactions: + + You can define an action by implementing the :py:class:`Action ` class: .. code-block:: python diff --git a/docs/concepts/parallelism.rst b/docs/concepts/parallelism.rst index 0b3a4913..9253f068 100644 --- a/docs/concepts/parallelism.rst +++ b/docs/concepts/parallelism.rst @@ -81,12 +81,12 @@ This looks as follows -- in this case we're running the same LLM over different class TestMultiplePrompts(MapStates): - def action(self) -> Action | Callable | RunnableGraph: + def action(self, state: State, inputs: Dict[str, Any]) -> Action | Callable | RunnableGraph: # make sure to add a name to the action # This is not necessary for subgraphs, as actions will already have names return query_llm.with_name("query_llm") - def states(self, state: State) -> Generator[State, None, None]: + def states(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]) -> Generator[State, None, None]: # You could easily have a list_prompts upstream action that writes to "prompts" in state # And loop through those # This hardcodes for simplicity @@ -146,7 +146,7 @@ For case (2) (mapping actions over the same state) you implement the ``MapAction from burr.core import action, state from burr.core.parallelism import MapActions, RunnableGraph - from typing import Callable, Generator, List + from typing import Callable, Generator, List, Dict, Any @action(reads=["prompt", "model"], writes=["llm_output"]) def query_llm(state: State, model: str) -> State: @@ -155,7 +155,7 @@ For case (2) (mapping actions over the same state) you implement the ``MapAction class TestMultipleModels(MapActions): - def actions(self, state: State) -> Generator[Action | Callable | RunnableGraph, None, None]: + def actions(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]) -> Generator[Action | Callable | RunnableGraph, None, None]: # Make sure to add a name to the action if you use bind() with a function, # note that these can be different actions, functions, etc... # in this case we're using `.bind()` to create multiple actions, but we can use some mix of @@ -167,7 +167,7 @@ For case (2) (mapping actions over the same state) you implement the ``MapAction ] yield action - def state(self, state: State) -> State: + def state(self, state: State, inputs: Dict[str, Any]) -> State: return state.update(prompt="What is the meaning of life?") def reduce(self, state: State, states: Generator[State, None, None]) -> State: @@ -314,7 +314,7 @@ This might look as follows -- say we have a simple subflow that takes in a raw p def action(self, state: State, inputs: Dict[str, Any]) -> Action | Callable | RunnableGraph: return runnable_graph - def states(self, state: State) -> Generator[State, None, None]: + def states(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]) -> Generator[State, None, None]: for prompt in [ "What is the meaning of life?", "What is the airspeed velocity of an unladen swallow?", @@ -332,15 +332,11 @@ it can run just as the single prompt we did above. Note this is also doable for Passing inputs -------------- -.. note:: - - Should ``MapOverInputs`` be its own class? Or should we have ``bind_from_state(prompt="prompt_field_in_state")`` that allows you to pass it in as - state and just use the mapping capabilities? - -Each of these can (optionally) produce ``inputs`` by yielding/returning a tuple from the ``states``/``actions`` function. - -This is useful if you want to vary the inputs. Note this is the same as passing ``inputs=`` to ``app.run``. +Parallel actions can accept inputs in the same way that class-based actions do. In order to accept inputs you have to declare them in the class. As we're using the :ref:`class-based API `, +this is done by declaring the ``inputs`` property -- a list of strings that are used in inputs. Note you have to use the superclasses +inputs as well to ensure it has everything it needs -- we will likely be automating this. +This looks as follows: .. code-block:: python @@ -379,17 +375,22 @@ This is useful if you want to vary the inputs. Note this is the same as passing def action(self) -> Action | Callable | RunnableGraph: return runnable_graph - def states(self, state: State) -> Generator[Tuple[State, dict], None, None]: - for prompt in [ - "What is the meaning of life?", - "What is the airspeed velocity of an unladen swallow?", - "What is the best way to cook a steak?", - ]: - yield state.update(prompt=prompt), {"model": "gpt-4"} # pass in the model as an input + @property + def inputs(self) -> List[str]: + return ["prompts"] + super().inputs # make sure to include the superclass inputs + + def states(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]) -> Generator[Tuple[State, dict], None, None]: + for prompt in inputs["prompts"]: + yield state.update(prompt=prompt) ... # same as above +.. note:: + + Should ``MapOverInputs`` be its own class? Or should we have ``bind_from_state(prompt="prompt_field_in_state")`` that allows you to pass it in as + state and just use the mapping capabilities? Or are we happy as it currently is because we can pass in inputs through `MapStates`/`MapActions` (as shown above). + Lower-level API ===============