Skip to content

Commit

Permalink
Adds SERDE for persisters
Browse files Browse the repository at this point in the history
The persisters can now be pickled and unpickled.
This means that they can be properly serialized across
process boundaries.
  • Loading branch information
skrawcz committed Dec 26, 2024
1 parent 402623e commit 3b70072
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 0 deletions.
21 changes: 21 additions & 0 deletions burr/integrations/persisters/b_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,27 @@ def save(
def __del__(self):
self.client.close()

def __getstate__(self) -> dict:
state = self.__dict__.copy()
state["connection_params"] = {
"uri": self.client.address[0],
"port": self.client.address[1],
"db_name": self.db.name,
"collection_name": self.collection.name,
}
del state["client"]
del state["db"]
del state["collection"]
return state

def __setstate__(self, state: dict):
connection_params = state.pop("connection_params")
# we assume MongoClient.
self.client = MongoClient(connection_params["uri"], connection_params["port"])
self.db = self.client[connection_params["db_name"]]
self.collection = self.db[connection_params["collection_name"]]
self.__dict__.update(state)


class MongoDBPersister(MongoDBBasePersister):
"""A class used to represent a MongoDB Persister.
Expand Down
20 changes: 20 additions & 0 deletions burr/integrations/persisters/b_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,26 @@ def save(
def __del__(self):
self.connection.close()

def __getstate__(self) -> dict:
state = self.__dict__.copy()
if not hasattr(self.connection, "connection_pool"):
logger.warning("Redis connection is not serializable.")
return state
state["connection_params"] = {
"host": self.connection.connection_pool.connection_kwargs["host"],
"port": self.connection.connection_pool.connection_kwargs["port"],
"db": self.connection.connection_pool.connection_kwargs["db"],
"password": self.connection.connection_pool.connection_kwargs["password"],
}
del state["connection"]
return state

def __setstate__(self, state: dict):
connection_params = state.pop("connection_params")
# we assume normal redis client.
self.connection = redis.Redis(**connection_params)
self.__dict__.update(state)


class RedisPersister(RedisBasePersister):
"""A class used to represent a Redis Persister.
Expand Down
23 changes: 23 additions & 0 deletions burr/integrations/persisters/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,29 @@ def __del__(self):
# closes connection at end when things are being shutdown.
self.connection.close()

def __getstate__(self) -> dict:
state = self.__dict__.copy()
if not hasattr(self.connection, "info"):
logger.warning(
"Postgresql information for connection object not available. Cannot serialize persister."
)
return state
state["connection_params"] = {
"dbname": self.connection.info.dbname,
"user": self.connection.info.user,
"password": self.connection.info.password,
"host": self.connection.info.host,
"port": self.connection.info.port,
}
del state["connection"]
return state

def __setstate__(self, state: dict):
connection_params = state.pop("connection_params")
# we assume normal psycopg2 client.
self.connection = psycopg2.connect(**connection_params)
self.__dict__.update(state)


if __name__ == "__main__":
# test the PostgreSQLPersister class
Expand Down
19 changes: 19 additions & 0 deletions tests/integrations/persisters/test_b_mongodb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pickle

import pytest

Expand Down Expand Up @@ -46,3 +47,21 @@ def test_backwards_compatible_persister():
assert data["state"].get_all() == {"a": 5, "b": 5}

persister.collection.drop()


def test_serialization_with_pickle(mongodb_persister):
# Save some state
mongodb_persister.save(
"pk", "app_id_serde", 1, "pos", state.State({"a": 1, "b": 2}), "completed"
)

# Serialize the persister
serialized_persister = pickle.dumps(mongodb_persister)

# Deserialize the persister
deserialized_persister = pickle.loads(serialized_persister)

# Load the state from the deserialized persister
data = deserialized_persister.load("pk", "app_id_serde", 1)

assert data["state"].get_all() == {"a": 1, "b": 2}
19 changes: 19 additions & 0 deletions tests/integrations/persisters/test_b_redis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pickle

import pytest

Expand Down Expand Up @@ -70,3 +71,21 @@ def test_redis_persister_class_backwards_compatible():
data = persister.load("pk", "app_id", 2)
assert data["state"].get_all() == {"a": 4, "b": 5}
persister.connection.close()


def test_serialization_with_pickle(redis_persister_with_ns):
# Save some state
redis_persister_with_ns.save(
"pk", "app_id_serde", 1, "pos", state.State({"a": 1, "b": 2}), "completed"
)

# Serialize the persister
serialized_persister = pickle.dumps(redis_persister_with_ns)

# Deserialize the persister
deserialized_persister = pickle.loads(serialized_persister)

# Load the state from the deserialized persister
data = deserialized_persister.load("pk", "app_id_serde", 1)

assert data["state"].get_all() == {"a": 1, "b": 2}
19 changes: 19 additions & 0 deletions tests/integrations/persisters/test_postgresql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pickle

import pytest

Expand Down Expand Up @@ -67,3 +68,21 @@ def test_is_initialized_false():
table_name="testtable2",
)
assert not persister.is_initialized()


def test_serialization_with_pickle(postgresql_persister):
# Save some state
postgresql_persister.save(
"pk", "app_id_serde", 1, "pos", state.State({"a": 1, "b": 2}), "completed"
)

# Serialize the persister
serialized_persister = pickle.dumps(postgresql_persister)

# Deserialize the persister
deserialized_persister = pickle.loads(serialized_persister)

# Load the state from the deserialized persister
data = deserialized_persister.load("pk", "app_id_serde", 1)

assert data["state"].get_all() == {"a": 1, "b": 2}

0 comments on commit 3b70072

Please sign in to comment.