Skip to content

Commit

Permalink
Refactors Redis Persister (#471)
Browse files Browse the repository at this point in the history
This adds a new class RedisBasePersister that the old one inherits from.

The reason to do this is that if someone wants to test things, then being
able to inject a mock redis client makes this simpler. So refactoring
to enable that while still being backwards compatible.

Appropriately marked things as deprecated and tests have been updated.
  • Loading branch information
skrawcz authored Dec 13, 2024
1 parent bc3af0d commit c24cad4
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 22 deletions.
82 changes: 65 additions & 17 deletions burr/integrations/persisters/b_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,50 @@
logger = logging.getLogger(__name__)


class RedisPersister(persistence.BaseStatePersister):
"""A class used to represent a Redis Persister.
class RedisBasePersister(persistence.BaseStatePersister):
"""Main class for Redis persister.
Use this class if you want to directly control injecting the Redis client.
This class is responsible for persisting state data to a Redis database.
It inherits from the BaseStatePersister class.
Note: We didn't create the right constructor for the initial implementation of the RedisPersister class,
so this is an attempt to fix that in a backwards compatible way.
"""

def __init__(
self,
@classmethod
def from_values(
cls,
host: str,
port: int,
db: int,
password: str = None,
serde_kwargs: dict = None,
redis_client_kwargs: dict = None,
namespace: str = None,
) -> "RedisBasePersister":
"""Creates a new instance of the RedisBasePersister from passed in values."""
if redis_client_kwargs is None:
redis_client_kwargs = {}
connection = redis.Redis(
host=host, port=port, db=db, password=password, **redis_client_kwargs
)
return cls(connection, serde_kwargs, namespace)

def __init__(
self,
connection,
serde_kwargs: dict = None,
namespace: str = None,
):
"""Initializes the RedisPersister class.
:param host:
:param port:
:param db:
:param password:
:param serde_kwargs:
:param redis_client_kwargs: Additional keyword arguments to pass to the redis.Redis client.
:param connection: the redis connection object.
:param serde_kwargs: serialization and deserialization keyword arguments to pass to state SERDE.
:param namespace: The name of the project to optionally use in the key prefix.
"""
if redis_client_kwargs is None:
redis_client_kwargs = {}
self.connection = redis.Redis(
host=host, port=port, db=db, password=password, **redis_client_kwargs
)
self.connection = connection
self.serde_kwargs = serde_kwargs or {}
self.namespace = namespace if namespace else ""

Expand Down Expand Up @@ -149,9 +161,45 @@ def __del__(self):
self.connection.close()


class RedisPersister(RedisBasePersister):
"""A class used to represent a Redis Persister.
This class is deprecated. Use RedisBasePersister.from_values() instead.
"""

def __init__(
self,
host: str,
port: int,
db: int,
password: str = None,
serde_kwargs: dict = None,
redis_client_kwargs: dict = None,
namespace: str = None,
):
"""Initializes the RedisPersister class.
This is deprecated. Use RedisBasePersister.from_values() instead.
:param host:
:param port:
:param db:
:param password:
:param serde_kwargs:
:param redis_client_kwargs: Additional keyword arguments to pass to the redis.Redis client.
:param namespace: The name of the project to optionally use in the key prefix.
"""
if redis_client_kwargs is None:
redis_client_kwargs = {}
connection = redis.Redis(
host=host, port=port, db=db, password=password, **redis_client_kwargs
)
super(RedisPersister, self).__init__(connection, serde_kwargs, namespace)


if __name__ == "__main__":
# test the RedisPersister class
persister = RedisPersister("localhost", 6379, 0)
# test the RedisBasePersister class
persister = RedisBasePersister.from_values("localhost", 6379, 0)

persister.initialize()
persister.save("pk", "app_id", 2, "pos", state.State({"a": 1, "b": 2}), "completed")
Expand Down
3 changes: 1 addition & 2 deletions docs/reference/persister.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ Currently we support the following, although we highly recommend you contribute

.. automethod:: __init__


.. autoclass:: burr.integrations.persisters.b_redis.RedisPersister
.. autoclass:: burr.integrations.persisters.b_redis.RedisBasePersister
:members:

.. automethod:: __init__
Expand Down
15 changes: 12 additions & 3 deletions tests/integrations/persisters/test_b_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@
import pytest

from burr.core import state
from burr.integrations.persisters.b_redis import RedisPersister
from burr.integrations.persisters.b_redis import RedisBasePersister, RedisPersister

if not os.environ.get("BURR_CI_INTEGRATION_TESTS") == "true":
pytest.skip("Skipping integration tests", allow_module_level=True)


@pytest.fixture
def redis_persister():
persister = RedisPersister(host="localhost", port=6379, db=0)
persister = RedisBasePersister.from_values(host="localhost", port=6379, db=0)
yield persister
persister.connection.close()


@pytest.fixture
def redis_persister_with_ns():
persister = RedisPersister(host="localhost", port=6379, db=0, namespace="test")
persister = RedisBasePersister.from_values(host="localhost", port=6379, db=0, namespace="test")
yield persister
persister.connection.close()

Expand Down Expand Up @@ -61,3 +61,12 @@ def test_list_app_ids_with_ns(redis_persister_with_ns):
def test_load_nonexistent_key_with_ns(redis_persister_with_ns):
state_data = redis_persister_with_ns.load("pk", "nonexistent_key")
assert state_data is None


def test_redis_persister_class_backwards_compatible():
"""Tests that the RedisPersister class is still backwards compatible."""
persister = RedisPersister(host="localhost", port=6379, db=0, namespace="backwardscompatible")
persister.save("pk", "app_id", 2, "pos", state.State({"a": 4, "b": 5}), "completed")
data = persister.load("pk", "app_id", 2)
assert data["state"].get_all() == {"a": 4, "b": 5}
persister.connection.close()

0 comments on commit c24cad4

Please sign in to comment.