diff --git a/burr/integrations/persisters/b_mongodb.py b/burr/integrations/persisters/b_mongodb.py index 1aaa8bea..0481db17 100644 --- a/burr/integrations/persisters/b_mongodb.py +++ b/burr/integrations/persisters/b_mongodb.py @@ -130,6 +130,27 @@ def save( def __del__(self): self.client.close() + def __getstate__(self) -> dict: + state = self.__dict__.copy() + state["connection_params"] = { + "uri": self.client.address[0], + "port": self.client.address[1], + "db_name": self.db.name, + "collection_name": self.collection.name, + } + del state["client"] + del state["db"] + del state["collection"] + return state + + def __setstate__(self, state: dict): + connection_params = state.pop("connection_params") + # we assume MongoClient. + self.client = MongoClient(connection_params["uri"], connection_params["port"]) + self.db = self.client[connection_params["db_name"]] + self.collection = self.db[connection_params["collection_name"]] + self.__dict__.update(state) + class MongoDBPersister(MongoDBBasePersister): """A class used to represent a MongoDB Persister. diff --git a/burr/integrations/persisters/b_redis.py b/burr/integrations/persisters/b_redis.py index 06fa1f45..25231bf9 100644 --- a/burr/integrations/persisters/b_redis.py +++ b/burr/integrations/persisters/b_redis.py @@ -160,6 +160,26 @@ def save( def __del__(self): self.connection.close() + def __getstate__(self) -> dict: + state = self.__dict__.copy() + if not hasattr(self.connection, "connection_pool"): + logger.warning("Redis connection is not serializable.") + return state + state["connection_params"] = { + "host": self.connection.connection_pool.connection_kwargs["host"], + "port": self.connection.connection_pool.connection_kwargs["port"], + "db": self.connection.connection_pool.connection_kwargs["db"], + "password": self.connection.connection_pool.connection_kwargs["password"], + } + del state["connection"] + return state + + def __setstate__(self, state: dict): + connection_params = state.pop("connection_params") + # we assume normal redis client. + self.connection = redis.Redis(**connection_params) + self.__dict__.update(state) + class RedisPersister(RedisBasePersister): """A class used to represent a Redis Persister. diff --git a/burr/integrations/persisters/postgresql.py b/burr/integrations/persisters/postgresql.py index c5d152c2..86f50f66 100644 --- a/burr/integrations/persisters/postgresql.py +++ b/burr/integrations/persisters/postgresql.py @@ -246,6 +246,29 @@ def __del__(self): # closes connection at end when things are being shutdown. self.connection.close() + def __getstate__(self) -> dict: + state = self.__dict__.copy() + if not hasattr(self.connection, "info"): + logger.warning( + "Postgresql information for connection object not available. Cannot serialize persister." + ) + return state + state["connection_params"] = { + "dbname": self.connection.info.dbname, + "user": self.connection.info.user, + "password": self.connection.info.password, + "host": self.connection.info.host, + "port": self.connection.info.port, + } + del state["connection"] + return state + + def __setstate__(self, state: dict): + connection_params = state.pop("connection_params") + # we assume normal psycopg2 client. + self.connection = psycopg2.connect(**connection_params) + self.__dict__.update(state) + if __name__ == "__main__": # test the PostgreSQLPersister class diff --git a/tests/integrations/persisters/test_b_mongodb.py b/tests/integrations/persisters/test_b_mongodb.py index 7a665aa9..c9d13d00 100644 --- a/tests/integrations/persisters/test_b_mongodb.py +++ b/tests/integrations/persisters/test_b_mongodb.py @@ -1,4 +1,5 @@ import os +import pickle import pytest @@ -46,3 +47,21 @@ def test_backwards_compatible_persister(): assert data["state"].get_all() == {"a": 5, "b": 5} persister.collection.drop() + + +def test_serialization_with_pickle(mongodb_persister): + # Save some state + mongodb_persister.save( + "pk", "app_id_serde", 1, "pos", state.State({"a": 1, "b": 2}), "completed" + ) + + # Serialize the persister + serialized_persister = pickle.dumps(mongodb_persister) + + # Deserialize the persister + deserialized_persister = pickle.loads(serialized_persister) + + # Load the state from the deserialized persister + data = deserialized_persister.load("pk", "app_id_serde", 1) + + assert data["state"].get_all() == {"a": 1, "b": 2} diff --git a/tests/integrations/persisters/test_b_redis.py b/tests/integrations/persisters/test_b_redis.py index c9aa6b58..80551cec 100644 --- a/tests/integrations/persisters/test_b_redis.py +++ b/tests/integrations/persisters/test_b_redis.py @@ -1,4 +1,5 @@ import os +import pickle import pytest @@ -70,3 +71,21 @@ def test_redis_persister_class_backwards_compatible(): data = persister.load("pk", "app_id", 2) assert data["state"].get_all() == {"a": 4, "b": 5} persister.connection.close() + + +def test_serialization_with_pickle(redis_persister_with_ns): + # Save some state + redis_persister_with_ns.save( + "pk", "app_id_serde", 1, "pos", state.State({"a": 1, "b": 2}), "completed" + ) + + # Serialize the persister + serialized_persister = pickle.dumps(redis_persister_with_ns) + + # Deserialize the persister + deserialized_persister = pickle.loads(serialized_persister) + + # Load the state from the deserialized persister + data = deserialized_persister.load("pk", "app_id_serde", 1) + + assert data["state"].get_all() == {"a": 1, "b": 2} diff --git a/tests/integrations/persisters/test_postgresql.py b/tests/integrations/persisters/test_postgresql.py index c45c6fa0..ea4f0794 100644 --- a/tests/integrations/persisters/test_postgresql.py +++ b/tests/integrations/persisters/test_postgresql.py @@ -1,4 +1,5 @@ import os +import pickle import pytest @@ -67,3 +68,21 @@ def test_is_initialized_false(): table_name="testtable2", ) assert not persister.is_initialized() + + +def test_serialization_with_pickle(postgresql_persister): + # Save some state + postgresql_persister.save( + "pk", "app_id_serde", 1, "pos", state.State({"a": 1, "b": 2}), "completed" + ) + + # Serialize the persister + serialized_persister = pickle.dumps(postgresql_persister) + + # Deserialize the persister + deserialized_persister = pickle.loads(serialized_persister) + + # Load the state from the deserialized persister + data = deserialized_persister.load("pk", "app_id_serde", 1) + + assert data["state"].get_all() == {"a": 1, "b": 2}