From c24cad41d361eb38b2b40f2c0c6dd86d61aa1c87 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Fri, 13 Dec 2024 14:43:31 -0800 Subject: [PATCH] Refactors Redis Persister (#471) This adds a new class RedisBasePersister that the old one inherits from. The reason to do this is that if someone wants to test things, then being able to inject a mock redis client makes this simpler. So refactoring to enable that while still being backwards compatible. Appropriately marked things as deprecated and tests have been updated. --- burr/integrations/persisters/b_redis.py | 82 +++++++++++++++---- docs/reference/persister.rst | 3 +- tests/integrations/persisters/test_b_redis.py | 15 +++- 3 files changed, 78 insertions(+), 22 deletions(-) diff --git a/burr/integrations/persisters/b_redis.py b/burr/integrations/persisters/b_redis.py index 65f38285..06fa1f45 100644 --- a/burr/integrations/persisters/b_redis.py +++ b/burr/integrations/persisters/b_redis.py @@ -16,15 +16,21 @@ logger = logging.getLogger(__name__) -class RedisPersister(persistence.BaseStatePersister): - """A class used to represent a Redis Persister. +class RedisBasePersister(persistence.BaseStatePersister): + """Main class for Redis persister. + + Use this class if you want to directly control injecting the Redis client. This class is responsible for persisting state data to a Redis database. It inherits from the BaseStatePersister class. + + Note: We didn't create the right constructor for the initial implementation of the RedisPersister class, + so this is an attempt to fix that in a backwards compatible way. """ - def __init__( - self, + @classmethod + def from_values( + cls, host: str, port: int, db: int, @@ -32,22 +38,28 @@ def __init__( serde_kwargs: dict = None, redis_client_kwargs: dict = None, namespace: str = None, + ) -> "RedisBasePersister": + """Creates a new instance of the RedisBasePersister from passed in values.""" + if redis_client_kwargs is None: + redis_client_kwargs = {} + connection = redis.Redis( + host=host, port=port, db=db, password=password, **redis_client_kwargs + ) + return cls(connection, serde_kwargs, namespace) + + def __init__( + self, + connection, + serde_kwargs: dict = None, + namespace: str = None, ): """Initializes the RedisPersister class. - :param host: - :param port: - :param db: - :param password: - :param serde_kwargs: - :param redis_client_kwargs: Additional keyword arguments to pass to the redis.Redis client. + :param connection: the redis connection object. + :param serde_kwargs: serialization and deserialization keyword arguments to pass to state SERDE. :param namespace: The name of the project to optionally use in the key prefix. """ - if redis_client_kwargs is None: - redis_client_kwargs = {} - self.connection = redis.Redis( - host=host, port=port, db=db, password=password, **redis_client_kwargs - ) + self.connection = connection self.serde_kwargs = serde_kwargs or {} self.namespace = namespace if namespace else "" @@ -149,9 +161,45 @@ def __del__(self): self.connection.close() +class RedisPersister(RedisBasePersister): + """A class used to represent a Redis Persister. + + This class is deprecated. Use RedisBasePersister.from_values() instead. + """ + + def __init__( + self, + host: str, + port: int, + db: int, + password: str = None, + serde_kwargs: dict = None, + redis_client_kwargs: dict = None, + namespace: str = None, + ): + """Initializes the RedisPersister class. + + This is deprecated. Use RedisBasePersister.from_values() instead. + + :param host: + :param port: + :param db: + :param password: + :param serde_kwargs: + :param redis_client_kwargs: Additional keyword arguments to pass to the redis.Redis client. + :param namespace: The name of the project to optionally use in the key prefix. + """ + if redis_client_kwargs is None: + redis_client_kwargs = {} + connection = redis.Redis( + host=host, port=port, db=db, password=password, **redis_client_kwargs + ) + super(RedisPersister, self).__init__(connection, serde_kwargs, namespace) + + if __name__ == "__main__": - # test the RedisPersister class - persister = RedisPersister("localhost", 6379, 0) + # test the RedisBasePersister class + persister = RedisBasePersister.from_values("localhost", 6379, 0) persister.initialize() persister.save("pk", "app_id", 2, "pos", state.State({"a": 1, "b": 2}), "completed") diff --git a/docs/reference/persister.rst b/docs/reference/persister.rst index 85ff6fb5..3b37a23e 100644 --- a/docs/reference/persister.rst +++ b/docs/reference/persister.rst @@ -48,8 +48,7 @@ Currently we support the following, although we highly recommend you contribute .. automethod:: __init__ - -.. autoclass:: burr.integrations.persisters.b_redis.RedisPersister +.. autoclass:: burr.integrations.persisters.b_redis.RedisBasePersister :members: .. automethod:: __init__ diff --git a/tests/integrations/persisters/test_b_redis.py b/tests/integrations/persisters/test_b_redis.py index ddfe7355..c9aa6b58 100644 --- a/tests/integrations/persisters/test_b_redis.py +++ b/tests/integrations/persisters/test_b_redis.py @@ -3,7 +3,7 @@ import pytest from burr.core import state -from burr.integrations.persisters.b_redis import RedisPersister +from burr.integrations.persisters.b_redis import RedisBasePersister, RedisPersister if not os.environ.get("BURR_CI_INTEGRATION_TESTS") == "true": pytest.skip("Skipping integration tests", allow_module_level=True) @@ -11,14 +11,14 @@ @pytest.fixture def redis_persister(): - persister = RedisPersister(host="localhost", port=6379, db=0) + persister = RedisBasePersister.from_values(host="localhost", port=6379, db=0) yield persister persister.connection.close() @pytest.fixture def redis_persister_with_ns(): - persister = RedisPersister(host="localhost", port=6379, db=0, namespace="test") + persister = RedisBasePersister.from_values(host="localhost", port=6379, db=0, namespace="test") yield persister persister.connection.close() @@ -61,3 +61,12 @@ def test_list_app_ids_with_ns(redis_persister_with_ns): def test_load_nonexistent_key_with_ns(redis_persister_with_ns): state_data = redis_persister_with_ns.load("pk", "nonexistent_key") assert state_data is None + + +def test_redis_persister_class_backwards_compatible(): + """Tests that the RedisPersister class is still backwards compatible.""" + persister = RedisPersister(host="localhost", port=6379, db=0, namespace="backwardscompatible") + persister.save("pk", "app_id", 2, "pos", state.State({"a": 4, "b": 5}), "completed") + data = persister.load("pk", "app_id", 2) + assert data["state"].get_all() == {"a": 4, "b": 5} + persister.connection.close()