diff --git a/dev_requirements.txt b/dev_requirements.txt index a8da4b49cd..eae75a3539 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -17,3 +17,5 @@ uvloop vulture>=2.3.0 wheel>=0.30.0 numpy>=1.24.0 +requests>=2.23.0 +aiohttp>=3.0.0 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index dd78bb6a2c..8c7164ba57 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -134,6 +134,13 @@ def pytest_addoption(parser): help="Name of the Redis master service that the sentinels are monitoring", ) + parser.addoption( + "--endpoints-config", + action="store", + default="endpoints.json", + help="Path to the Redis endpoints configuration file", + ) + def _get_info(redis_url): client = redis.Redis.from_url(redis_url) diff --git a/tests/scenario/__init__.py b/tests/scenario/__init__.py new file mode 100644 index 0000000000..2b4bd0c423 --- /dev/null +++ b/tests/scenario/__init__.py @@ -0,0 +1,57 @@ +import dataclasses +import json +import os.path +from typing import List +from urllib.parse import urlparse + +import pytest + + +@dataclasses.dataclass +class Endpoint: + bdb_id: int + username: str + password: str + tls: bool + endpoints: List[str] + + @property + def url(self): + parsed_url = urlparse(self.endpoints[0]) + + if self.tls: + parsed_url = parsed_url._replace(scheme="rediss") + + domain = parsed_url.netloc.split("@")[-1] + domain = f"{self.username}:{self.password}@{domain}" + + parsed_url = parsed_url._replace(netloc=domain) + + return parsed_url.geturl() + + @classmethod + def from_dict(cls, data: dict): + field_names = set(f.name for f in dataclasses.fields(cls)) + return cls(**{k: v for k, v in data.items() if k in field_names}) + + +def get_endpoint(request: pytest.FixtureRequest, endpoint_name: str) -> Endpoint: + endpoints_config_path = request.config.getoption("--endpoints-config") + + if not (endpoints_config_path and os.path.exists(endpoints_config_path)): + raise FileNotFoundError( + f"Endpoints config file not found: {endpoints_config_path}" + ) + + try: + with open(endpoints_config_path, "r") as f: + endpoints_config = json.load(f) + except Exception as e: + raise ValueError( + f"Failed to load endpoints config file: {endpoints_config_path}" + ) from e + + if not (isinstance(endpoints_config, dict) and endpoint_name in endpoints_config): + raise ValueError(f"Endpoint not found in config: {endpoint_name}") + + return Endpoint.from_dict(endpoints_config.get(endpoint_name)) diff --git a/tests/scenario/fake_app.py b/tests/scenario/fake_app.py new file mode 100644 index 0000000000..e643f52591 --- /dev/null +++ b/tests/scenario/fake_app.py @@ -0,0 +1,63 @@ +import multiprocessing +import typing +from multiprocessing import Event as PEvent +from multiprocessing import Process +from threading import Event, Thread +from unittest.mock import patch + +from redis import Redis + + +class FakeApp: + + def __init__(self, client: Redis, logic: typing.Callable[[Redis], None]): + self.client = client + self.logic = logic + self.disconnects = 0 + + def run(self) -> (Event, Thread): + e = Event() + t = Thread(target=self._run_logic, args=(e,)) + t.start() + return e, t + + def _run_logic(self, e: Event): + with patch.object( + self.client, "_disconnect_raise", wraps=self.client._disconnect_raise + ) as spy: + while not e.is_set(): + self.logic(self.client) + + self.disconnects = spy.call_count + + +class FakeSubscriber: + + def __init__(self, client: Redis, logic: typing.Callable[[dict], None]): + self.client = client + self.logic = logic + self.disconnects = multiprocessing.Value("i", 0) + + def run(self, channel: str) -> (PEvent, Process): + e, started = PEvent(), PEvent() + p = Process(target=self._run_logic, args=(e, started, channel)) + p.start() + return e, started, p + + def _run_logic(self, should_stop: PEvent, started: PEvent, channel: str): + pubsub = self.client.pubsub() + + with patch.object( + pubsub, "_disconnect_raise_connect", wraps=pubsub._disconnect_raise_connect + ) as spy_pubsub: + pubsub.subscribe(channel) + + started.set() + + while not should_stop.is_set(): + message = pubsub.get_message(ignore_subscribe_messages=True, timeout=1) + + if message: + self.logic(message) + + self.disconnects.value = spy_pubsub.call_count diff --git a/tests/scenario/fault_injection_client.py b/tests/scenario/fault_injection_client.py new file mode 100644 index 0000000000..750298779d --- /dev/null +++ b/tests/scenario/fault_injection_client.py @@ -0,0 +1,39 @@ +import requests + + +class TriggeredAction: + + def __init__(self, client: "FaultInjectionClient", data: dict): + self.client = client + self.action_id = data["action_id"] + self.data = data + + def refresh(self): + self.data = self.client.get_action(self.action_id) + + @property + def status(self): + if "status" not in self.data: + return "pending" + return self.data["status"] + + def wait_until_complete(self): + while self.status not in ("success", "failed"): + self.refresh() + return self.status + + +class FaultInjectionClient: + def __init__(self, base_url: str = "http://127.0.0.1:20324"): + self.base_url = base_url + + def trigger_action(self, action_type: str, parameters: dict): + response = requests.post( + f"{self.base_url}/action", + json={"type": action_type, "parameters": parameters}, + ) + return TriggeredAction(self, response.json()) + + def get_action(self, action_id: str): + response = requests.get(f"{self.base_url}/action/{action_id}") + return response.json() diff --git a/tests/scenario/test_connection_interruptions.py b/tests/scenario/test_connection_interruptions.py new file mode 100644 index 0000000000..a1ce067108 --- /dev/null +++ b/tests/scenario/test_connection_interruptions.py @@ -0,0 +1,157 @@ +import multiprocessing +import time +from typing import List + +import pytest +from redis import BusyLoadingError, Redis +from redis.backoff import ExponentialBackoff +from redis.exceptions import ConnectionError as RedisConnectionError +from redis.exceptions import TimeoutError as RedisTimeoutError +from redis.retry import Retry + +from ..conftest import _get_client +from . import Endpoint, get_endpoint +from .fake_app import FakeApp, FakeSubscriber +from .fault_injection_client import FaultInjectionClient + + +@pytest.fixture +def endpoint_name(): + return "re-standalone" + + +@pytest.fixture +def endpoint(request: pytest.FixtureRequest, endpoint_name: str): + try: + return get_endpoint(request, endpoint_name) + except FileNotFoundError as e: + pytest.skip( + f"Skipping scenario test because endpoints file is missing: {str(e)}" + ) + + +@pytest.fixture +def clients(request: pytest.FixtureRequest, endpoint: Endpoint): + # Use Recommended settings + retry = Retry(ExponentialBackoff(base=1), 3) + + clients = [] + + for _ in range(2): + r = _get_client( + Redis, + request, + decode_responses=True, + from_url=endpoint.url, + retry=retry, + retry_on_error=[ + BusyLoadingError, + RedisConnectionError, + RedisTimeoutError, + # FIXME: This is a workaround for a bug in redis-py + # https://github.com/redis/redis-py/issues/3203 + ConnectionError, + TimeoutError, + ], + ) + r.flushdb() + clients.append(r) + return clients + + +@pytest.fixture +def fault_injection_client(request: pytest.FixtureRequest): + return FaultInjectionClient() + + +@pytest.mark.parametrize("action", ("dmc_restart", "network_failure")) +def test_connection_interruptions( + clients: List[Redis], + endpoint: Endpoint, + fault_injection_client: FaultInjectionClient, + action: str, +): + client = clients.pop() + app = FakeApp(client, lambda c: c.set("foo", "bar")) + + stop_app, thread = app.run() + + triggered_action = fault_injection_client.trigger_action( + action, {"bdb_id": endpoint.bdb_id} + ) + + triggered_action.wait_until_complete() + + stop_app.set() + thread.join() + + if triggered_action.status == "failed": + pytest.fail(f"Action failed: {triggered_action.data['error']}") + + assert app.disconnects > 0, "Client did not disconnect" + + +@pytest.mark.parametrize("action", ("dmc_restart", "network_failure")) +def test_pubsub_with_connection_interruptions( + clients: List[Redis], + endpoint: Endpoint, + fault_injection_client: FaultInjectionClient, + action: str, +): + channel = "test" + + # Subscriber is executed in a separate process to ensure it reacts + # to the disconnection at the same time as the publisher + with multiprocessing.Manager() as manager: + received_messages = manager.list() + + def read_message(message): + nonlocal received_messages + if message and message["type"] == "message": + received_messages.append(message["data"]) + + subscriber_client = clients.pop() + subscriber = FakeSubscriber(subscriber_client, read_message) + stop_subscriber, subscriber_started, subscriber_t = subscriber.run(channel) + + # Allow subscriber subscribe to the channel + subscriber_started.wait(timeout=5) + + messages_sent = 0 + + def publish_message(c): + nonlocal messages_sent, channel + messages_sent += 1 + c.publish(channel, messages_sent) + + publisher_client = clients.pop() + publisher = FakeApp(publisher_client, publish_message) + stop_publisher, publisher_t = publisher.run() + + triggered_action = fault_injection_client.trigger_action( + action, {"bdb_id": endpoint.bdb_id} + ) + + triggered_action.wait_until_complete() + last_message_sent_after_trigger = messages_sent + + time.sleep(3) # Wait for the publisher to send more messages + + stop_publisher.set() + publisher_t.join() + + stop_subscriber.set() + subscriber_t.join() + + assert publisher.disconnects > 0 + assert subscriber.disconnects.value > 0 + + if triggered_action.status == "failed": + pytest.fail(f"Action failed: {triggered_action.data['error']}") + + assert ( + last_message_sent_after_trigger < messages_sent + ), "No messages were sent after the failure" + assert ( + int(received_messages[-1]) == messages_sent + ), "Not all messages were received" diff --git a/tests/test_asyncio/scenario/__init__.py b/tests/test_asyncio/scenario/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/scenario/fake_app.py b/tests/test_asyncio/scenario/fake_app.py new file mode 100644 index 0000000000..ea4d7fd22a --- /dev/null +++ b/tests/test_asyncio/scenario/fake_app.py @@ -0,0 +1,62 @@ +import typing +from asyncio import Event, Task, create_task +from unittest.mock import patch + +from redis.asyncio import Redis + + +class AsyncFakeApp: + + def __init__( + self, client: Redis, logic: typing.Callable[[Redis], typing.Awaitable[None]] + ): + self.client = client + self.logic = logic + self.disconnects = 0 + + async def run(self): + e = Event() + t = create_task(self._run_logic(e)) + return e, t + + async def _run_logic(self, e: Event): + with patch.object( + self.client, "_disconnect_raise", wraps=self.client._disconnect_raise + ) as spy: + while not e.is_set(): + await self.logic(self.client) + + self.disconnects = spy.call_count + + +class AsyncFakeSubscriber: + + def __init__( + self, client: Redis, logic: typing.Callable[[dict], typing.Awaitable[None]] + ): + self.client = client + self.logic = logic + self.disconnects = 0 + + async def run(self, channel: str) -> (Event, Task): + e = Event() + t = create_task(self._run_logic(e, channel)) + return e, t + + async def _run_logic(self, should_stop: Event, channel: str): + pubsub = self.client.pubsub() + + with patch.object( + pubsub, "_disconnect_raise_connect", wraps=pubsub._disconnect_raise_connect + ) as spy_pubsub: + await pubsub.subscribe(channel) + + while not should_stop.is_set(): + message = await pubsub.get_message( + ignore_subscribe_messages=True, timeout=1 + ) + + if message: + await self.logic(message) + + self.disconnects = spy_pubsub.call_count diff --git a/tests/test_asyncio/scenario/fault_injection_client.py b/tests/test_asyncio/scenario/fault_injection_client.py new file mode 100644 index 0000000000..4fcc685ded --- /dev/null +++ b/tests/test_asyncio/scenario/fault_injection_client.py @@ -0,0 +1,40 @@ +import aiohttp + + +class TriggeredAction: + + def __init__(self, client: "AsyncFaultInjectionClient", data: dict): + self.client = client + self.action_id = data["action_id"] + self.data = data + + async def refresh(self): + self.data = await self.client.get_action(self.action_id) + + @property + def status(self): + if "status" not in self.data: + return "pending" + return self.data["status"] + + async def wait_until_complete(self): + while self.status not in ("success", "failed"): + await self.refresh() + return self.status + + +class AsyncFaultInjectionClient: + def __init__(self, base_url: str = "http://127.0.0.1:20324"): + self.base_url = base_url + self.session = aiohttp.ClientSession() + + async def trigger_action(self, action_type: str, parameters: dict): + async with self.session.post( + f"{self.base_url}/action", + json={"type": action_type, "parameters": parameters}, + ) as response: + return TriggeredAction(self, await response.json()) + + async def get_action(self, action_id: str): + async with self.session.get(f"{self.base_url}/action/{action_id}") as response: + return await response.json() diff --git a/tests/test_asyncio/scenario/test_connection_interruptions.py b/tests/test_asyncio/scenario/test_connection_interruptions.py new file mode 100644 index 0000000000..c18720cf40 --- /dev/null +++ b/tests/test_asyncio/scenario/test_connection_interruptions.py @@ -0,0 +1,150 @@ +import asyncio +from typing import List + +import pytest +from redis.asyncio import BusyLoadingError, Redis +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialBackoff +from redis.exceptions import ConnectionError as RedisConnectionError +from redis.exceptions import TimeoutError +from tests.scenario import Endpoint, get_endpoint + +from .fake_app import AsyncFakeApp, AsyncFakeSubscriber +from .fault_injection_client import AsyncFaultInjectionClient + + +@pytest.fixture +async def endpoint_name(): + return "re-standalone" + + +@pytest.fixture +async def endpoint(request: pytest.FixtureRequest, endpoint_name: str): + try: + return get_endpoint(request, endpoint_name) + except FileNotFoundError as e: + pytest.skip( + f"Skipping scenario test because endpoints file is missing: {str(e)}" + ) + + +@pytest.fixture +async def clients( + request: pytest.FixtureRequest, endpoint: Endpoint, create_redis: callable +): + # Use Recommended settings + retry = Retry(ExponentialBackoff(base=1), 5) + + clients = [] + + for _ in range(2): + client = await create_redis( + endpoint.url, + decode_responses=True, + retry=retry, + retry_on_error=[ + BusyLoadingError, + RedisConnectionError, + TimeoutError, + # FIXME: This is a workaround for a bug in redis-py + # https://github.com/redis/redis-py/issues/3203 + ConnectionError, + OSError, + ], + retry_on_timeout=True, + ) + await client.flushdb() + clients.append(client) + + return clients + + +@pytest.fixture +async def fault_injection_client(request: pytest.FixtureRequest): + return AsyncFaultInjectionClient() + + +@pytest.mark.parametrize("action", ("dmc_restart", "network_failure")) +async def test_connection_interruptions( + clients: List[Redis], + endpoint: Endpoint, + fault_injection_client: AsyncFaultInjectionClient, + action: str, +): + client = clients.pop() + app = AsyncFakeApp(client, lambda c: c.set("foo", "bar")) + + stop_app, task = await app.run() + + triggered_action = await fault_injection_client.trigger_action( + action, {"bdb_id": endpoint.bdb_id} + ) + + await triggered_action.wait_until_complete() + + stop_app.set() + await task + + if triggered_action.status == "failed": + pytest.fail(f"Action failed: {triggered_action.data['error']}") + + assert app.disconnects > 0 + + +@pytest.mark.parametrize("action", ("dmc_restart",)) # "network_failure")) +async def test_pubsub_with_connection_interruptions( + clients: List[Redis], + endpoint: Endpoint, + fault_injection_client: AsyncFaultInjectionClient, + action: str, +): + channel = "test" + + received_messages = [] + + async def read_message(message): + nonlocal received_messages + if message and message["type"] == "message": + received_messages.append(message["data"]) + + messages_sent = 0 + + async def publish_message(c): + nonlocal messages_sent, channel + messages_sent += 1 + await c.publish(channel, messages_sent) + + subscriber_client = clients.pop() + publisher_client = clients.pop() + + subscriber = AsyncFakeSubscriber(subscriber_client, read_message) + stop_subscriber, subscriber_t = await subscriber.run(channel) + + publisher = AsyncFakeApp(publisher_client, publish_message) + stop_publisher, publisher_t = await publisher.run() + + triggered_action = await fault_injection_client.trigger_action( + action, {"bdb_id": endpoint.bdb_id} + ) + + await triggered_action.wait_until_complete() + last_message_sent_after_trigger = messages_sent + + if triggered_action.status == "failed": + pytest.fail(f"Action failed: {triggered_action.data['error']}") + + await asyncio.sleep(3) + + stop_publisher.set() + await publisher_t + + stop_subscriber.set() + await subscriber_t + + assert publisher.disconnects > 0 + assert subscriber.disconnects > 0 + + assert ( + last_message_sent_after_trigger < messages_sent + ), "No messages were sent after the failure" + assert int(received_messages[-1]) == messages_sent, "Not all messages were received"