From 6cfd3f9b6b13b617efeb7ee8a42e195be00b88db Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Fri, 10 Jan 2025 19:04:30 -0800 Subject: [PATCH] Moves aiosqlite/persister import to plugins --- burr/core/persistence.py | 240 ------------------- burr/integrations/persisters/b_aiosqlite.py | 253 ++++++++++++++++++++ 2 files changed, 253 insertions(+), 240 deletions(-) create mode 100644 burr/integrations/persisters/b_aiosqlite.py diff --git a/burr/core/persistence.py b/burr/core/persistence.py index d6574540..3b302ff1 100644 --- a/burr/core/persistence.py +++ b/burr/core/persistence.py @@ -6,8 +6,6 @@ from collections import defaultdict from typing import Any, Dict, Literal, Optional, TypedDict -import aiosqlite - from burr.common.types import BaseCopyable from burr.core import Action from burr.core.state import State, logger @@ -502,243 +500,6 @@ def __setstate__(self, state): ) -class AsyncSQLitePersister(AsyncBaseStatePersister, BaseCopyable): - """Class for asynchronous SQLite persistence of state. This is a simple implementation. - - SQLite is specifically single-threaded and `aiosqlite `_ - creates async support through multi-threading. This persister is mainly here for quick prototyping and testing; - we suggest to consider a different database with native async support for production. - - Note the third-party library `aiosqlite `_, - is maintained and considered stable considered stable: https://github.com/omnilib/aiosqlite/issues/309. - """ - - def copy(self) -> "Self": - return AsyncSQLitePersister( - db_path=self.db_path, - table_name=self.table_name, - serde_kwargs=self.serde_kwargs, - connect_kwargs=self._connect_kwargs, - ) - - PARTITION_KEY_DEFAULT = "" - - @classmethod - async def from_values( - cls, - db_path: str, - table_name: str = "burr_state", - serde_kwargs: dict = None, - connect_kwargs: dict = None, - ) -> "AsyncSQLitePersister": - """Creates a new instance of the AsyncSQLitePersister from passed in values. 
- - :param db_path: the path the DB will be stored. - :param table_name: the table name to store things under. - :param serde_kwargs: kwargs for state serialization/deserialization. - :param connect_kwargs: kwargs to pass to the aiosqlite.connect method. - :return: async sqlite persister instance with an open connection. You are responsible - for closing the connection yourself. - """ - connection = await aiosqlite.connect( - db_path, **connect_kwargs if connect_kwargs is not None else {} - ) - return cls(connection, table_name, serde_kwargs) - - def __init__( - self, - connection, - table_name: str = "burr_state", - serde_kwargs: dict = None, - ): - """Constructor. - - NOTE: you are responsible to handle closing of the connection / teardown manually. To help, - we provide a close() method. - - :param connection: the path the DB will be stored. - :param table_name: the table name to store things under. - :param serde_kwargs: kwargs for state serialization/deserialization. - """ - self.connection = connection - self.table_name = table_name - self.serde_kwargs = serde_kwargs or {} - self._initialized = False - - async def create_table_if_not_exists(self, table_name: str): - """Helper function to create the table where things are stored if it doesn't exist.""" - cursor = await self.connection.cursor() - await cursor.execute( - f""" - CREATE TABLE IF NOT EXISTS {table_name} ( - partition_key TEXT DEFAULT '{AsyncSQLitePersister.PARTITION_KEY_DEFAULT}', - app_id TEXT NOT NULL, - sequence_id INTEGER NOT NULL, - position TEXT NOT NULL, - status TEXT NOT NULL, - state TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - PRIMARY KEY (partition_key, app_id, sequence_id, position) - )""" - ) - await cursor.execute( - f""" - CREATE INDEX IF NOT EXISTS {table_name}_created_at_index ON {table_name} (created_at); - """ - ) - await self.connection.commit() - - async def initialize(self): - """Asynchronously creates the table if it doesn't exist""" - # Usage - await 
self.create_table_if_not_exists(self.table_name) - self._initialized = True - - async def is_initialized(self) -> bool: - """This checks to see if the table has been created in the database or not. - It defaults to using the initialized field, else queries the database to see if the table exists. - It then sets the initialized field to True if the table exists. - """ - if self._initialized: - return True - - cursor = await self.connection.cursor() - await cursor.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (self.table_name,) - ) - self._initialized = await cursor.fetchone() is not None - return self._initialized - - async def list_app_ids(self, partition_key: Optional[str] = None, **kwargs) -> list[str]: - partition_key = ( - partition_key - if partition_key is not None - else AsyncSQLitePersister.PARTITION_KEY_DEFAULT - ) - - cursor = await self.connection.cursor() - await cursor.execute( - f"SELECT DISTINCT app_id FROM {self.table_name} " - f"WHERE partition_key = ? " - f"ORDER BY created_at DESC", - (partition_key,), - ) - app_ids = [row[0] for row in await cursor.fetchall()] - return app_ids - - async def load( - self, - partition_key: Optional[str], - app_id: Optional[str], - sequence_id: Optional[int] = None, - **kwargs, - ) -> Optional[PersistedStateData]: - """Asynchronously loads state for a given partition id. - - Depending on the parameters, this will return the last thing written, the last thing written for a given app_id, - or a specific sequence_id for a given app_id. 
- - :param partition_key: - :param app_id: - :param sequence_id: - :return: - """ - partition_key = ( - partition_key - if partition_key is not None - else AsyncSQLitePersister.PARTITION_KEY_DEFAULT - ) - logger.debug("Loading %s, %s, %s", partition_key, app_id, sequence_id) - cursor = await self.connection.cursor() - if app_id is None: - # get latest for all app_ids - await cursor.execute( - f"SELECT position, state, sequence_id, app_id, created_at, status FROM {self.table_name} " - f"WHERE partition_key = ? " - f"ORDER BY CREATED_AT DESC LIMIT 1", - (partition_key,), - ) - elif sequence_id is None: - await cursor.execute( - f"SELECT position, state, sequence_id, app_id, created_at, status FROM {self.table_name} " - f"WHERE partition_key = ? AND app_id = ? " - f"ORDER BY sequence_id DESC LIMIT 1", - (partition_key, app_id), - ) - else: - await cursor.execute( - f"SELECT position, state, sequence_id, app_id, created_at, status FROM {self.table_name} " - f"WHERE partition_key = ? AND app_id = ? AND sequence_id = ?", - (partition_key, app_id, sequence_id), - ) - row = await cursor.fetchone() - if row is None: - return None - _state = State.deserialize(json.loads(row[1]), **self.serde_kwargs) - return { - "partition_key": partition_key, - "app_id": row[3], - "sequence_id": row[2], - "position": row[0], - "state": _state, - "created_at": row[4], - "status": row[5], - } - - async def save( - self, - partition_key: Optional[str], - app_id: str, - sequence_id: int, - position: str, - state: State, - status: Literal["completed", "failed"], - **kwargs, - ): - """ - Asynchronously saves the state for a given app_id, sequence_id, and position. - - This method connects to the SQLite database, converts the state to a JSON string, and inserts a new record - into the table with the provided partition_key, app_id, sequence_id, position, and state. After the operation, - it commits the changes and closes the connection to the database. - - :param partition_key: The partition key. 
This could be None, but it's up to the persister to whether - that is a valid value it can handle. - :param app_id: The identifier for the app instance being recorded. - :param sequence_id: The state corresponding to a specific point in time. - :param position: The position in the sequence of states. - :param state: The state to be saved, an instance of the State class. - :param status: The status of this state, either "completed" or "failed". If "failed" the state is what it was - before the action was applied. - :return: None - """ - logger.debug( - "saving %s, %s, %s, %s, %s, %s", - partition_key, - app_id, - sequence_id, - position, - state, - status, - ) - partition_key = ( - partition_key - if partition_key is not None - else AsyncSQLitePersister.PARTITION_KEY_DEFAULT - ) - cursor = await self.connection.cursor() - json_state = json.dumps(state.serialize(**self.serde_kwargs)) - await cursor.execute( - f"INSERT INTO {self.table_name} (partition_key, app_id, sequence_id, position, state, status) " - f"VALUES (?, ?, ?, ?, ?, ?)", - (partition_key, app_id, sequence_id, position, json_state, status), - ) - await self.connection.commit() - - async def close(self): - await self.connection.close() - - class InMemoryPersister(BaseStatePersister): """In-memory persister for testing purposes. 
This is not recommended for production use.""" @@ -846,7 +607,6 @@ async def save( SQLLitePersister = SQLitePersister -AsyncSQLLitePersister = AsyncSQLitePersister if __name__ == "__main__": s = SQLitePersister(db_path=".SQLite.db", table_name="test1")
diff --git a/burr/integrations/persisters/b_aiosqlite.py b/burr/integrations/persisters/b_aiosqlite.py
new file mode 100644
index 00000000..68730b5d
--- /dev/null
+++ b/burr/integrations/persisters/b_aiosqlite.py
@@ -0,0 +1,253 @@
+import json
+import logging
+from typing import Literal, Optional
+
+import aiosqlite
+
+from burr.common.types import BaseCopyable
+from burr.core import State
+from burr.core.persistence import AsyncBaseStatePersister, PersistedStateData
+
+logger = logging.getLogger(__name__)
+
+try:
+    from typing import Self
+except ImportError:
+    Self = None
+
+
+class AsyncSQLitePersister(AsyncBaseStatePersister, BaseCopyable):
+    """Class for asynchronous SQLite persistence of state. This is a simple implementation.
+
+    SQLite is specifically single-threaded and `aiosqlite <https://github.com/omnilib/aiosqlite>`_
+    creates async support through multi-threading. This persister is mainly here for quick prototyping and testing;
+    we suggest to consider a different database with native async support for production.
+
+    Note the third-party library `aiosqlite <https://github.com/omnilib/aiosqlite>`_
+    is maintained and considered stable: https://github.com/omnilib/aiosqlite/issues/309.
+    """
+
+    def copy(self) -> "Self":
+        """Returns a new persister that shares this instance's open connection, table name and serde kwargs."""
+        return AsyncSQLitePersister(
+            connection=self.connection,
+            table_name=self.table_name,
+            serde_kwargs=self.serde_kwargs,
+        )
+
+    PARTITION_KEY_DEFAULT = ""
+
+    @classmethod
+    async def from_values(
+        cls,
+        db_path: str,
+        table_name: str = "burr_state",
+        serde_kwargs: Optional[dict] = None,
+        connect_kwargs: Optional[dict] = None,
+    ) -> "AsyncSQLitePersister":
+        """Creates a new instance of the AsyncSQLitePersister from passed in values.
+
+        :param db_path: the path where the DB will be stored.
+        :param table_name: the table name to store things under.
+        :param serde_kwargs: kwargs for state serialization/deserialization.
+        :param connect_kwargs: kwargs to pass to the aiosqlite.connect method.
+        :return: async sqlite persister instance with an open connection. You are responsible
+            for closing the connection yourself.
+        """
+        connection = await aiosqlite.connect(
+            db_path, **connect_kwargs if connect_kwargs is not None else {}
+        )
+        return cls(connection, table_name, serde_kwargs)
+
+    def __init__(
+        self,
+        connection,
+        table_name: str = "burr_state",
+        serde_kwargs: Optional[dict] = None,
+    ):
+        """Constructor.
+
+        NOTE: you are responsible to handle closing of the connection / teardown manually. To help,
+        we provide a close() method.
+
+        :param connection: an already-open aiosqlite connection (see from_values() to create one from a path).
+        :param table_name: the table name to store things under.
+        :param serde_kwargs: kwargs for state serialization/deserialization.
+        """
+        self.connection = connection
+        self.table_name = table_name
+        self.serde_kwargs = serde_kwargs or {}
+        self._initialized = False
+
+    async def create_table_if_not_exists(self, table_name: str):
+        """Helper function to create the table where things are stored if it doesn't exist."""
+        cursor = await self.connection.cursor()
+        await cursor.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS {table_name} (
+                partition_key TEXT DEFAULT '{AsyncSQLitePersister.PARTITION_KEY_DEFAULT}',
+                app_id TEXT NOT NULL,
+                sequence_id INTEGER NOT NULL,
+                position TEXT NOT NULL,
+                status TEXT NOT NULL,
+                state TEXT NOT NULL,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                PRIMARY KEY (partition_key, app_id, sequence_id, position)
+            )"""
+        )
+        await cursor.execute(
+            f"""
+            CREATE INDEX IF NOT EXISTS {table_name}_created_at_index ON {table_name} (created_at);
+        """
+        )
+        await self.connection.commit()
+
+    async def initialize(self):
+        """Asynchronously creates the table if it doesn't exist"""
+        # Create the table and index first; the flag below is only set once that succeeds.
+        await self.create_table_if_not_exists(self.table_name)
+        self._initialized = True
+
+    async def is_initialized(self) -> bool:
+        """This checks to see if the table has been created in the database or not.
+        It defaults to using the initialized field, else queries the database to see if the table exists.
+        It then sets the initialized field to True if the table exists.
+        """
+        if self._initialized:
+            return True
+        cursor = await self.connection.cursor()
+        await cursor.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (self.table_name,)
+        )
+        self._initialized = await cursor.fetchone() is not None
+        return self._initialized
+
+    async def list_app_ids(self, partition_key: Optional[str] = None, **kwargs) -> list[str]:
+        """Lists the distinct app_ids stored under a partition key (None uses PARTITION_KEY_DEFAULT)."""
+        partition_key = (
+            partition_key
+            if partition_key is not None
+            else AsyncSQLitePersister.PARTITION_KEY_DEFAULT
+        )
+
+        cursor = await self.connection.cursor()
+        await cursor.execute(
+            f"SELECT DISTINCT app_id FROM {self.table_name} "
+            f"WHERE partition_key = ? "
+            f"ORDER BY created_at DESC",
+            (partition_key,),
+        )
+        return [row[0] for row in await cursor.fetchall()]
+
+    async def load(
+        self,
+        partition_key: Optional[str],
+        app_id: Optional[str],
+        sequence_id: Optional[int] = None,
+        **kwargs,
+    ) -> Optional[PersistedStateData]:
+        """Asynchronously loads state for a given partition id.
+
+        Depending on the parameters, this will return the last thing written, the last thing written for a given app_id,
+        or a specific sequence_id for a given app_id.
+
+        :param partition_key: partition to read from; None falls back to PARTITION_KEY_DEFAULT.
+        :param app_id: app to load; None loads the most recently created row in the partition.
+        :param sequence_id: sequence to load; None loads the highest sequence_id stored for the app.
+        :return: the persisted state data, or None if no matching row exists.
+        """
+        partition_key = (
+            partition_key
+            if partition_key is not None
+            else AsyncSQLitePersister.PARTITION_KEY_DEFAULT
+        )
+        logger.debug("Loading %s, %s, %s", partition_key, app_id, sequence_id)
+        cursor = await self.connection.cursor()
+        if app_id is None:
+            # get latest for all app_ids
+            await cursor.execute(
+                f"SELECT position, state, sequence_id, app_id, created_at, status FROM {self.table_name} "
+                f"WHERE partition_key = ? "
+                f"ORDER BY CREATED_AT DESC LIMIT 1",
+                (partition_key,),
+            )
+        elif sequence_id is None:
+            await cursor.execute(
+                f"SELECT position, state, sequence_id, app_id, created_at, status FROM {self.table_name} "
+                f"WHERE partition_key = ? AND app_id = ? "
+                f"ORDER BY sequence_id DESC LIMIT 1",
+                (partition_key, app_id),
+            )
+        else:
+            await cursor.execute(
+                f"SELECT position, state, sequence_id, app_id, created_at, status FROM {self.table_name} "
+                f"WHERE partition_key = ? AND app_id = ? AND sequence_id = ?",
+                (partition_key, app_id, sequence_id),
+            )
+        row = await cursor.fetchone()
+        if row is None:
+            return None
+        _state = State.deserialize(json.loads(row[1]), **self.serde_kwargs)
+        return {
+            "partition_key": partition_key,
+            "app_id": row[3],
+            "sequence_id": row[2],
+            "position": row[0],
+            "state": _state,
+            "created_at": row[4],
+            "status": row[5],
+        }
+
+    async def save(
+        self,
+        partition_key: Optional[str],
+        app_id: str,
+        sequence_id: int,
+        position: str,
+        state: State,
+        status: Literal["completed", "failed"],
+        **kwargs,
+    ):
+        """
+        Asynchronously saves the state for a given app_id, sequence_id, and position.
+
+        This method uses the persister's open connection, converts the state to a JSON string, and inserts a new record
+        into the table with the provided partition_key, app_id, sequence_id, position, and state. After the operation,
+        it commits the changes. The connection is left open; call close() when you are done with it.
+
+        :param partition_key: The partition key. This could be None, but it's up to the persister to whether
+            that is a valid value it can handle.
+        :param app_id: The identifier for the app instance being recorded.
+        :param sequence_id: The state corresponding to a specific point in time.
+        :param position: The position in the sequence of states.
+        :param state: The state to be saved, an instance of the State class.
+        :param status: The status of this state, either "completed" or "failed". If "failed" the state is what it was
+            before the action was applied.
+        :return: None
+        """
+        logger.debug(
+            "saving %s, %s, %s, %s, %s, %s",
+            partition_key,
+            app_id,
+            sequence_id,
+            position,
+            state,
+            status,
+        )
+        partition_key = (
+            partition_key
+            if partition_key is not None
+            else AsyncSQLitePersister.PARTITION_KEY_DEFAULT
+        )
+        cursor = await self.connection.cursor()
+        json_state = json.dumps(state.serialize(**self.serde_kwargs))
+        await cursor.execute(
+            f"INSERT INTO {self.table_name} (partition_key, app_id, sequence_id, position, state, status) "
+            f"VALUES (?, ?, ?, ?, ?, ?)",
+            (partition_key, app_id, sequence_id, position, json_state, status),
+        )
+        await self.connection.commit()
+
+    async def close(self):
+        """Closes the underlying aiosqlite connection."""
+        await self.connection.close()