From 5e3e18f7a6b09eb62074338d395847baa133e6e6 Mon Sep 17 00:00:00 2001 From: jernejfrank Date: Mon, 30 Dec 2024 21:04:29 +0800 Subject: [PATCH] Add local async persisters and tests Prototyping and testing async persisters: - Add AsyncDevNull and AsyncInMemory persisters for tests - Added support for async sqlite persister - Test Async persister interface, async builder, async application --- burr/core/persistence.py | 317 +++++++++++++++++++++++++++++++++ pyproject.toml | 8 +- tests/core/test_application.py | 118 +++++++++++- tests/core/test_persistence.py | 266 ++++++++++++++++++++++++++- tests/test_end_to_end.py | 116 +++++++++++- 5 files changed, 817 insertions(+), 8 deletions(-) diff --git a/burr/core/persistence.py b/burr/core/persistence.py index 7e3450ed..d6574540 100644 --- a/burr/core/persistence.py +++ b/burr/core/persistence.py @@ -6,6 +6,8 @@ from collections import defaultdict from typing import Any, Dict, Literal, Optional, TypedDict +import aiosqlite + from burr.common.types import BaseCopyable from burr.core import Action from burr.core.state import State, logger @@ -261,6 +263,30 @@ def save( return +class AsyncDevNullPersister(AsyncBaseStatePersister): + """Does nothing asynchronously, do not use this. This is for testing only.""" + + async def load( + self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs + ) -> Optional[PersistedStateData]: + return None + + async def list_app_ids(self, partition_key: str, **kwargs) -> list[str]: + return [] + + async def save( + self, + partition_key: Optional[str], + app_id: str, + sequence_id: int, + position: str, + state: State, + status: Literal["completed", "failed"], + **kwargs, + ): + return + + class SQLitePersister(BaseStatePersister, BaseCopyable): """Class for SQLite persistence of state. This is a simple implementation.""" @@ -476,6 +502,243 @@ def __setstate__(self, state): ) +class AsyncSQLitePersister(AsyncBaseStatePersister, BaseCopyable): + """Class for asynchronous SQLite persistence of state. This is a simple implementation. + + SQLite is specifically single-threaded and `aiosqlite `_ + creates async support through multi-threading. This persister is mainly here for quick prototyping and testing; + we suggest to consider a different database with native async support for production. + + Note the third-party library `aiosqlite `_, + is maintained and considered stable considered stable: https://github.com/omnilib/aiosqlite/issues/309. + """ + + def copy(self) -> "Self": + return AsyncSQLitePersister( + db_path=self.db_path, + table_name=self.table_name, + serde_kwargs=self.serde_kwargs, + connect_kwargs=self._connect_kwargs, + ) + + PARTITION_KEY_DEFAULT = "" + + @classmethod + async def from_values( + cls, + db_path: str, + table_name: str = "burr_state", + serde_kwargs: dict = None, + connect_kwargs: dict = None, + ) -> "AsyncSQLitePersister": + """Creates a new instance of the AsyncSQLitePersister from passed in values. + + :param db_path: the path the DB will be stored. + :param table_name: the table name to store things under. + :param serde_kwargs: kwargs for state serialization/deserialization. + :param connect_kwargs: kwargs to pass to the aiosqlite.connect method. + :return: async sqlite persister instance with an open connection. You are responsible + for closing the connection yourself. + """ + connection = await aiosqlite.connect( + db_path, **connect_kwargs if connect_kwargs is not None else {} + ) + return cls(connection, table_name, serde_kwargs) + + def __init__( + self, + connection, + table_name: str = "burr_state", + serde_kwargs: dict = None, + ): + """Constructor. + + NOTE: you are responsible to handle closing of the connection / teardown manually. To help, + we provide a close() method. + + :param connection: the path the DB will be stored. + :param table_name: the table name to store things under. + :param serde_kwargs: kwargs for state serialization/deserialization. + """ + self.connection = connection + self.table_name = table_name + self.serde_kwargs = serde_kwargs or {} + self._initialized = False + + async def create_table_if_not_exists(self, table_name: str): + """Helper function to create the table where things are stored if it doesn't exist.""" + cursor = await self.connection.cursor() + await cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {table_name} ( + partition_key TEXT DEFAULT '{AsyncSQLitePersister.PARTITION_KEY_DEFAULT}', + app_id TEXT NOT NULL, + sequence_id INTEGER NOT NULL, + position TEXT NOT NULL, + status TEXT NOT NULL, + state TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (partition_key, app_id, sequence_id, position) + )""" + ) + await cursor.execute( + f""" + CREATE INDEX IF NOT EXISTS {table_name}_created_at_index ON {table_name} (created_at); + """ + ) + await self.connection.commit() + + async def initialize(self): + """Asynchronously creates the table if it doesn't exist""" + # Usage + await self.create_table_if_not_exists(self.table_name) + self._initialized = True + + async def is_initialized(self) -> bool: + """This checks to see if the table has been created in the database or not. + It defaults to using the initialized field, else queries the database to see if the table exists. + It then sets the initialized field to True if the table exists. + """ + if self._initialized: + return True + + cursor = await self.connection.cursor() + await cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (self.table_name,) + ) + self._initialized = await cursor.fetchone() is not None + return self._initialized + + async def list_app_ids(self, partition_key: Optional[str] = None, **kwargs) -> list[str]: + partition_key = ( + partition_key + if partition_key is not None + else AsyncSQLitePersister.PARTITION_KEY_DEFAULT + ) + + cursor = await self.connection.cursor() + await cursor.execute( + f"SELECT DISTINCT app_id FROM {self.table_name} " + f"WHERE partition_key = ? " + f"ORDER BY created_at DESC", + (partition_key,), + ) + app_ids = [row[0] for row in await cursor.fetchall()] + return app_ids + + async def load( + self, + partition_key: Optional[str], + app_id: Optional[str], + sequence_id: Optional[int] = None, + **kwargs, + ) -> Optional[PersistedStateData]: + """Asynchronously loads state for a given partition id. + + Depending on the parameters, this will return the last thing written, the last thing written for a given app_id, + or a specific sequence_id for a given app_id. + + :param partition_key: + :param app_id: + :param sequence_id: + :return: + """ + partition_key = ( + partition_key + if partition_key is not None + else AsyncSQLitePersister.PARTITION_KEY_DEFAULT + ) + logger.debug("Loading %s, %s, %s", partition_key, app_id, sequence_id) + cursor = await self.connection.cursor() + if app_id is None: + # get latest for all app_ids + await cursor.execute( + f"SELECT position, state, sequence_id, app_id, created_at, status FROM {self.table_name} " + f"WHERE partition_key = ? " + f"ORDER BY CREATED_AT DESC LIMIT 1", + (partition_key,), + ) + elif sequence_id is None: + await cursor.execute( + f"SELECT position, state, sequence_id, app_id, created_at, status FROM {self.table_name} " + f"WHERE partition_key = ? AND app_id = ? " + f"ORDER BY sequence_id DESC LIMIT 1", + (partition_key, app_id), + ) + else: + await cursor.execute( + f"SELECT position, state, sequence_id, app_id, created_at, status FROM {self.table_name} " + f"WHERE partition_key = ? AND app_id = ? AND sequence_id = ?", + (partition_key, app_id, sequence_id), + ) + row = await cursor.fetchone() + if row is None: + return None + _state = State.deserialize(json.loads(row[1]), **self.serde_kwargs) + return { + "partition_key": partition_key, + "app_id": row[3], + "sequence_id": row[2], + "position": row[0], + "state": _state, + "created_at": row[4], + "status": row[5], + } + + async def save( + self, + partition_key: Optional[str], + app_id: str, + sequence_id: int, + position: str, + state: State, + status: Literal["completed", "failed"], + **kwargs, + ): + """ + Asynchronously saves the state for a given app_id, sequence_id, and position. + + This method connects to the SQLite database, converts the state to a JSON string, and inserts a new record + into the table with the provided partition_key, app_id, sequence_id, position, and state. After the operation, + it commits the changes and closes the connection to the database. + + :param partition_key: The partition key. This could be None, but it's up to the persister to whether + that is a valid value it can handle. + :param app_id: The identifier for the app instance being recorded. + :param sequence_id: The state corresponding to a specific point in time. + :param position: The position in the sequence of states. + :param state: The state to be saved, an instance of the State class. + :param status: The status of this state, either "completed" or "failed". If "failed" the state is what it was + before the action was applied. + :return: None + """ + logger.debug( + "saving %s, %s, %s, %s, %s, %s", + partition_key, + app_id, + sequence_id, + position, + state, + status, + ) + partition_key = ( + partition_key + if partition_key is not None + else AsyncSQLitePersister.PARTITION_KEY_DEFAULT + ) + cursor = await self.connection.cursor() + json_state = json.dumps(state.serialize(**self.serde_kwargs)) + await cursor.execute( + f"INSERT INTO {self.table_name} (partition_key, app_id, sequence_id, position, state, status) " + f"VALUES (?, ?, ?, ?, ?, ?)", + (partition_key, app_id, sequence_id, position, json_state, status), + ) + await self.connection.commit() + + async def close(self): + await self.connection.close() + + class InMemoryPersister(BaseStatePersister): """In-memory persister for testing purposes. This is not recommended for production use.""" @@ -529,7 +792,61 @@ def save( self._storage[partition_key][app_id].append(persisted_state) +class AsyncInMemoryPersister(AsyncBaseStatePersister): + """Sync in-memory persister for testing purposes. This is not recommended for production use.""" + + def __init__(self): + self._storage = defaultdict(lambda: defaultdict(list)) + + async def load( + self, partition_key: str, app_id: Optional[str], sequence_id: Optional[int] = None, **kwargs + ) -> Optional[PersistedStateData]: + # If no app_id provided, return None + if app_id is None: + return None + + if not (states := self._storage[partition_key][app_id]): + return None + + if sequence_id is None: + return states[-1] + + # Find states matching the specific sequence_id + matching_states = [state for state in states if state["sequence_id"] == sequence_id] + + # Return the latest state for this sequence_id, if exists + return matching_states[-1] if matching_states else None + + async def list_app_ids(self, partition_key: str, **kwargs) -> list[str]: + return list(self._storage[partition_key].keys()) + + async def save( + self, + partition_key: Optional[str], + app_id: str, + sequence_id: int, + position: str, + state: State, + status: Literal["completed", "failed"], + **kwargs, + ): + # Create a PersistedStateData entry + persisted_state: PersistedStateData = { + "partition_key": partition_key or "", + "app_id": app_id, + "sequence_id": sequence_id, + "position": position, + "state": state, + "created_at": datetime.datetime.now().isoformat(), + "status": status, + } + + # Store the state + self._storage[partition_key][app_id].append(persisted_state) + + SQLLitePersister = SQLitePersister +AsyncSQLLitePersister = AsyncSQLitePersister if __name__ == "__main__": s = SQLitePersister(db_path=".SQLite.db", table_name="test1") diff --git a/pyproject.toml b/pyproject.toml index f98906c5..695db35f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ streamlit = [ "streamlit", "graphviz", "matplotlib", - "sf-hamilton" + "sf-hamilton", ] hamilton = [ @@ -40,6 +40,10 @@ graphviz = [ "graphviz" ] +sqlite = [ + "aiosqlite" +] + postgresql = [ "psycopg2-binary" ] @@ -61,6 +65,7 @@ tests = [ "pydantic[email]", "pyarrow", "redis", + "aiosqlite", "burr[opentelemetry]", "burr[haystack]", "burr[ray]" @@ -77,6 +82,7 @@ documentation = [ "psycopg2-binary", "redis", "ray", + "aiosqlite", "sphinxcontrib-googleanalytics" ] diff --git a/tests/core/test_application.py b/tests/core/test_application.py index d9767cc3..26191462 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -44,6 +44,7 @@ ) from burr.core.graph import Graph, GraphBuilder, Transition from burr.core.persistence import ( + AsyncDevNullPersister, BaseStatePersister, DevNullPersister, PersistedStateData, @@ -2770,9 +2771,9 @@ def save( def test_application_builder_initialize_raises_on_broken_persistor(): """Persisters should return None when there is no state to be loaded and the default used.""" + counter_action = base_counter_action.with_name("counter") + result_action = Result("count").with_name("result") with pytest.raises(ValueError, match="but value for state was None"): - counter_action = base_counter_action.with_name("counter") - result_action = Result("count").with_name("result") ( ApplicationBuilder() .with_actions(counter_action, result_action) @@ -2787,6 +2788,37 @@ def test_application_builder_initialize_raises_on_broken_persistor(): ) +def test_load_from_sync_cannot_have_async_persistor_error(): + builder = ApplicationBuilder() + builder.initialize_from( + AsyncDevNullPersister(), + resume_at_next_action=True, + default_state={}, + default_entrypoint="foo", + ) + with pytest.raises( + ValueError, match="are building the sync application, but have used an async initializer." + ): + # we have not initialized + builder._load_from_sync_persister() + + +async def test_load_from_async_cannot_have_sync_persistor_error(): + await asyncio.sleep(0.00001) + builder = ApplicationBuilder() + builder.initialize_from( + DevNullPersister(), + resume_at_next_action=True, + default_state={}, + default_entrypoint="foo", + ) + with pytest.raises( + ValueError, match="are building the async application, but have used an sync initializer." + ): + # we have not initialized + await builder._load_from_async_persister() + + def test_application_builder_assigns_correct_actions_with_dual_api(): counter_action = base_counter_action.with_name("counter") result_action = Result("count") @@ -3290,13 +3322,51 @@ def test_builder_captures_typing_system(): assert state.data["count"] == 10 -def test_with_state_persister_is_not_initialized_error(tmp_path): +def test_set_sync_state_persister_cannot_have_async_error(): + builder = ApplicationBuilder() + persister = AsyncDevNullPersister() + builder.with_state_persister(persister) + with pytest.raises( + ValueError, match="are building the sync application, but have used an async persister." + ): + # we have not initialized + builder._set_sync_state_persister() + + +def test_set_sync_state_persister_is_not_initialized_error(tmp_path): builder = ApplicationBuilder() persister = SQLLitePersister(db_path=":memory:", table_name="test_table") + builder.with_state_persister(persister) + with pytest.raises(RuntimeError): + # we have not initialized + builder._set_sync_state_persister() + + +async def test_set_async_state_persister_cannot_have_sync_error(): + await asyncio.sleep(0.00001) + builder = ApplicationBuilder() + persister = DevNullPersister() + builder.with_state_persister(persister) + with pytest.raises( + ValueError, match="are building the async application, but have used an sync persister." + ): + # we have not initialized + await builder._set_async_state_persister() + +async def test_set_async_state_persister_is_not_initialized_error(tmp_path): + await asyncio.sleep(0.00001) + builder = ApplicationBuilder() + + class FakePersister(AsyncDevNullPersister): + async def is_initialized(self): + return False + + persister = FakePersister() + builder.with_state_persister(persister) with pytest.raises(RuntimeError): # we have not initialized - builder.with_state_persister(persister) + await builder._set_async_state_persister() def test_with_state_persister_is_initialized_not_implemented(): @@ -3419,3 +3489,43 @@ def test_remap_context_variable_without_mangled_context(): inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"} assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected + + +async def test_async_application_builder_initialize_raises_on_broken_persistor(): + """Persisters should return None when there is no state to be loaded and the default used.""" + await asyncio.sleep(0.00001) + counter_action = base_counter_action_async.with_name("counter") + result_action = Result("count").with_name("result") + + class AsyncBrokenPersister(AsyncDevNullPersister): + async def load( + self, + partition_key: str, + app_id: Optional[str], + sequence_id: Optional[int] = None, + **kwargs, + ) -> Optional[PersistedStateData]: + await asyncio.sleep(0.0001) + return dict( + partition_key="key", + app_id="id", + sequence_id=0, + position="foo", + state=None, + created_at="", + status="completed", + ) + + with pytest.raises(ValueError, match="but value for state was None"): + await ( + ApplicationBuilder() + .with_actions(counter_action, result_action) + .with_transitions(("counter", "result", default)) + .initialize_from( + AsyncBrokenPersister(), + resume_at_next_action=True, + default_state={}, + default_entrypoint="foo", + ) + .abuild() + ) diff --git a/tests/core/test_persistence.py b/tests/core/test_persistence.py index 8bd0a57f..592074fd 100644 --- a/tests/core/test_persistence.py +++ b/tests/core/test_persistence.py @@ -1,7 +1,16 @@ +import asyncio +from typing import Tuple + +import aiosqlite import pytest -from burr.core import State -from burr.core.persistence import InMemoryPersister, SQLLitePersister +from burr.core import ApplicationBuilder, State, action +from burr.core.persistence import ( + AsyncInMemoryPersister, + AsyncSQLLitePersister, + InMemoryPersister, + SQLLitePersister, +) @pytest.fixture( @@ -98,3 +107,256 @@ def test_persister_methods_none_partition_key(persistence, method_name: str, kwa # this doesn't guarantee that the results of `partition_key=None` and # `partition_key=persistence.PARTITION_KEY_DEFAULT`. This is hard to test because # these operations are stateful (i.e., read/write to a db) + + +class AsyncSQLLiteContextManager: + def __init__(self, sqlite_object): + self.client = sqlite_object + + async def __aenter__(self): + return self.client + + async def __aexit__(self, exc_type, exc, tb): + await self.client.close() + + +@pytest.fixture( + params=[ + {"which": "sqlite"}, + {"which": "memory"}, + ] +) +async def async_persistence(request): + which = request.param["which"] + if which == "sqlite": + sqlite_persister = await AsyncSQLLitePersister.from_values( + db_path=":memory:", table_name="test_table" + ) + async_context_manager = AsyncSQLLiteContextManager(sqlite_persister) + async with async_context_manager as client: + yield client + elif which == "memory": + yield AsyncInMemoryPersister() + + +async def test_async_persistence_saves_and_loads_state(async_persistence): + await asyncio.sleep(0.00001) + if hasattr(async_persistence, "initialize"): + await async_persistence.initialize() + await async_persistence.save( + "partition_key", "app_id", 1, "position", State({"key": "value"}), "status" + ) + loaded_state = await async_persistence.load("partition_key", "app_id") + assert loaded_state["state"] == State({"key": "value"}) + + +async def test_async_persistence_returns_none_when_no_state(async_persistence): + await asyncio.sleep(0.00001) + if hasattr(async_persistence, "initialize"): + await async_persistence.initialize() + loaded_state = await async_persistence.load("partition_key", "app_id") + assert loaded_state is None + + +async def test_async_persistence_lists_app_ids(async_persistence): + await asyncio.sleep(0.00001) + if hasattr(async_persistence, "initialize"): + await async_persistence.initialize() + await async_persistence.save( + "partition_key", "app_id1", 1, "position", State({"key": "value"}), "status" + ) + await async_persistence.save( + "partition_key", "app_id2", 1, "position", State({"key": "value"}), "status" + ) + app_ids = await async_persistence.list_app_ids("partition_key") + assert set(app_ids) == set(["app_id1", "app_id2"]) + + +@pytest.mark.parametrize( + "method_name,kwargs", + [ + ("list_app_ids", {"partition_key": None}), + ("load", {"partition_key": None, "app_id": "foo"}), + ( + "save", + { + "partition_key": None, + "app_id": "foo", + "sequence_id": 1, + "position": "position", + "state": State({"key": "value"}), + "status": "status", + }, + ), + ], +) +async def test_async_persister_methods_none_partition_key( + async_persistence, method_name: str, kwargs: dict +): + await asyncio.sleep(0.00001) + if hasattr(async_persistence, "initialize"): + await async_persistence.initialize() + method = getattr(async_persistence, method_name) + # method can be executed with `partition_key=None` + await method(**kwargs) + # this doesn't guarantee that the results of `partition_key=None` and + # `partition_key=persistence.PARTITION_KEY_DEFAULT`. This is hard to test because + # these operations are stateful (i.e., read/write to a db) + + +async def test_AsyncSQLLitePersister_from_values(): + await asyncio.sleep(0.00001) + connection = await aiosqlite.connect(":memory:") + sqlite_persister_init = AsyncSQLLitePersister(connection=connection, table_name="test_table") + sqlite_persister_from_values = await AsyncSQLLitePersister.from_values( + db_path=":memory:", table_name="test_table" + ) + + try: + sqlite_persister_init.connection == sqlite_persister_from_values.connection + except Exception as e: + raise e + finally: + await sqlite_persister_init.close() + await sqlite_persister_from_values.close() + + +async def test_AsyncSQLLitePersister_connection_shutdown(): + await asyncio.sleep(0.00001) + sqlite_persister = await AsyncSQLLitePersister.from_values( + db_path=":memory:", table_name="test_table" + ) + await sqlite_persister.close() + + +@pytest.fixture() +async def initializing_async_persistence(): + sqlite_persister = await AsyncSQLLitePersister.from_values( + db_path=":memory:", table_name="test_table" + ) + async_context_manager = AsyncSQLLiteContextManager(sqlite_persister) + async with async_context_manager as client: + yield client + + +async def test_async_persistence_initialization_creates_table(initializing_async_persistence): + await asyncio.sleep(0.00001) + await initializing_async_persistence.initialize() + assert await initializing_async_persistence.list_app_ids("partition_key") == [] + + +async def test_async_persistence_is_initialized_false(initializing_async_persistence): + await asyncio.sleep(0.00001) + assert not await initializing_async_persistence.is_initialized() + + +async def test_async_persistence_is_initialized_true(initializing_async_persistence): + await asyncio.sleep(0.00001) + await initializing_async_persistence.initialize() + assert await initializing_async_persistence.is_initialized() + + +async def test_asyncsqlite_persistence_is_initialized_true_new_connection(tmp_path): + await asyncio.sleep(0.00001) + db_path = tmp_path / "test.db" + p = await AsyncSQLLitePersister.from_values(db_path=db_path, table_name="test_table") + await p.initialize() + p2 = await AsyncSQLLitePersister.from_values(db_path=db_path, table_name="test_table") + try: + assert await p.is_initialized() + assert await p2.is_initialized() + except Exception as e: + raise e + finally: + await p.close() + await p2.close() + + +async def test_async_save_and_load_from_sqlite_persister_end_to_end(tmp_path): + await asyncio.sleep(0.00001) + + @action(reads=[], writes=["prompt", "chat_history"]) + async def dummy_input(state: State) -> Tuple[dict, State]: + await asyncio.sleep(0.0001) + if state["chat_history"]: + new = state["chat_history"][-1] + 1 + else: + new = 1 + return ( + {"prompt": "PROMPT"}, + state.update(prompt="PROMPT").append(chat_history=new), + ) + + @action(reads=["chat_history"], writes=["response", "chat_history"]) + async def dummy_response(state: State) -> Tuple[dict, State]: + await asyncio.sleep(0.0001) + if state["chat_history"]: + new = state["chat_history"][-1] + 1 + else: + new = 1 + return ( + {"response": "RESPONSE"}, + state.update(response="RESPONSE").append(chat_history=new), + ) + + db_path = tmp_path / "test.db" + sqlite_persister = await AsyncSQLLitePersister.from_values( + db_path=db_path, table_name="test_table" + ) + await sqlite_persister.initialize() + app = await ( + ApplicationBuilder() + .with_actions(dummy_input, dummy_response) + .with_transitions(("dummy_input", "dummy_response"), ("dummy_response", "dummy_input")) + .initialize_from( + initializer=sqlite_persister, + resume_at_next_action=True, + default_state={"chat_history": []}, + default_entrypoint="dummy_input", + ) + .with_state_persister(sqlite_persister) + .with_identifiers(app_id="test_1", partition_key="sqlite") + .abuild() + ) + + try: + *_, state = await app.arun(halt_after=["dummy_response"]) + assert state["chat_history"][0] == 1 + assert state["chat_history"][1] == 2 + del app + except Exception as e: + raise e + finally: + await sqlite_persister.close() + del sqlite_persister + + sqlite_persister_2 = await AsyncSQLLitePersister.from_values( + db_path=db_path, table_name="test_table" + ) + await sqlite_persister_2.initialize() + new_app = await ( + ApplicationBuilder() + .with_actions(dummy_input, dummy_response) + .with_transitions(("dummy_input", "dummy_response"), ("dummy_response", "dummy_input")) + .initialize_from( + initializer=sqlite_persister_2, + resume_at_next_action=True, + default_state={"chat_history": []}, + default_entrypoint="dummy_input", + ) + .with_state_persister(sqlite_persister_2) + .with_identifiers(app_id="test_1", partition_key="sqlite") + .abuild() + ) + + try: + assert new_app.state["chat_history"][0] == 1 + assert new_app.state["chat_history"][1] == 2 + + *_, state = await new_app.arun(halt_after=["dummy_response"]) + assert state["chat_history"][2] == 3 + assert state["chat_history"][3] == 4 + except Exception as e: + raise e + finally: + await sqlite_persister_2.close() diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index b309a241..23fd87af 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -3,12 +3,15 @@ see failures in these tests, you should make a unit test, demonstrate the failure there, then fix both in that test and the end-to-end test.""" import asyncio +import datetime import uuid from concurrent.futures import ThreadPoolExecutor from io import StringIO -from typing import Any, AsyncGenerator, Dict, Generator, Tuple +from typing import Any, AsyncGenerator, Dict, Generator, Literal, Optional, Tuple from unittest.mock import patch +import pytest + from burr.core import ( Action, ApplicationBuilder, @@ -369,3 +372,114 @@ def echo(state: State) -> Tuple[dict, State]: format="png", ) assert result["response"] == prompt + + +async def test_async_save_and_load_from_persister_end_to_end(): + await asyncio.sleep(0.00001) + + @action(reads=[], writes=["prompt", "chat_history"]) + async def dummy_input(state: State) -> Tuple[dict, State]: + await asyncio.sleep(0.0001) + if state["chat_history"]: + new = state["chat_history"][-1] + 1 + else: + new = 1 + return ( + {"prompt": "PROMPT"}, + state.update(prompt="PROMPT").append(chat_history=new), + ) + + @action(reads=["chat_history"], writes=["response", "chat_history"]) + async def dummy_response(state: State) -> Tuple[dict, State]: + await asyncio.sleep(0.0001) + if state["chat_history"]: + new = state["chat_history"][-1] + 1 + else: + new = 1 + return ( + {"response": "RESPONSE"}, + state.update(response="RESPONSE").append(chat_history=new), + ) + + class AsyncDummyPersister(persistence.AsyncBaseStatePersister): + def __init__(self): + self.persisted_state = None + + async def load( + self, + partition_key: str, + app_id: Optional[str], + sequence_id: Optional[int] = None, + **kwargs, + ) -> Optional[persistence.PersistedStateData]: + await asyncio.sleep(0.0001) + return self.persisted_state + + async def list_app_ids(self, partition_key: str, **kwargs) -> list[str]: + return [] + + async def save( + self, + partition_key: Optional[str], + app_id: str, + sequence_id: int, + position: str, + state: State, + status: Literal["completed", "failed"], + **kwargs, + ): + await asyncio.sleep(0.0001) + self.persisted_state: persistence.PersistedStateData = { + "partition_key": partition_key or "", + "app_id": app_id, + "sequence_id": sequence_id, + "position": position, + "state": state, + "created_at": datetime.datetime.now().isoformat(), + "status": status, + } + + dummy_persister = AsyncDummyPersister() + app = await ( + ApplicationBuilder() + .with_actions(dummy_input, dummy_response) + .with_transitions(("dummy_input", "dummy_response"), ("dummy_response", "dummy_input")) + .initialize_from( + initializer=dummy_persister, + resume_at_next_action=True, + default_state={"chat_history": []}, + default_entrypoint="dummy_input", + ) + .with_state_persister(dummy_persister) + .abuild() + ) + + *_, state = await app.arun(halt_after=["dummy_response"]) + + assert state["chat_history"][0] == 1 + assert state["chat_history"][1] == 2 + del app + + new_app = await ( + ApplicationBuilder() + .with_actions(dummy_input, dummy_response) + .with_transitions(("dummy_input", "dummy_response"), ("dummy_response", "dummy_input")) + .initialize_from( + initializer=dummy_persister, + resume_at_next_action=True, + default_state={"chat_history": []}, + default_entrypoint="dummy_input", + ) + .with_state_persister(dummy_persister) + .abuild() + ) + + assert new_app.state["chat_history"][0] == 1 + assert new_app.state["chat_history"][1] == 2 + + *_, state = await new_app.arun(halt_after=["dummy_response"]) + assert state["chat_history"][2] == 3 + assert state["chat_history"][3] == 4 + + with pytest.raises(ValueError, match="The application was build with .abuild()"): + *_, state = new_app.run(halt_after=["dummy_response"])