Add scenario tests for connection interruptions #3331

Open
wants to merge 5 commits into master
2 changes: 2 additions & 0 deletions dev_requirements.txt
@@ -17,3 +17,5 @@ uvloop
vulture>=2.3.0
wheel>=0.30.0
numpy>=1.24.0
requests>=2.23.0
aiohttp>=3.0.0
7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -134,6 +134,13 @@ def pytest_addoption(parser):
help="Name of the Redis master service that the sentinels are monitoring",
)

parser.addoption(
"--endpoints-config",
action="store",
default="endpoints.json",
help="Path to the Redis endpoints configuration file",
)


def _get_info(redis_url):
client = redis.Redis.from_url(redis_url)
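Note: the new option can be read anywhere through pytest's config object. A minimal sketch of a fixture that loads the whole mapping (the fixture name endpoints_config is illustrative; the scenario helpers below wrap the same lookup in get_endpoint):

import json

import pytest


@pytest.fixture
def endpoints_config(request: pytest.FixtureRequest) -> dict:
    # Sketch only: resolve the path passed via --endpoints-config
    # (default "endpoints.json") and load the full JSON mapping.
    path = request.config.getoption("--endpoints-config")
    with open(path, "r") as f:
        return json.load(f)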
57 changes: 57 additions & 0 deletions tests/scenario/__init__.py
@@ -0,0 +1,57 @@
import dataclasses
import json
import os.path
from typing import List
from urllib.parse import urlparse

import pytest


@dataclasses.dataclass
class Endpoint:
bdb_id: int
username: str
password: str
tls: bool
endpoints: List[str]

@property
def url(self):
parsed_url = urlparse(self.endpoints[0])

if self.tls:
parsed_url = parsed_url._replace(scheme="rediss")

domain = parsed_url.netloc.split("@")[-1]
domain = f"{self.username}:{self.password}@{domain}"

parsed_url = parsed_url._replace(netloc=domain)

return parsed_url.geturl()

@classmethod
def from_dict(cls, data: dict):
field_names = set(f.name for f in dataclasses.fields(cls))
return cls(**{k: v for k, v in data.items() if k in field_names})


def get_endpoint(request: pytest.FixtureRequest, endpoint_name: str) -> Endpoint:
endpoints_config_path = request.config.getoption("--endpoints-config")

if not (endpoints_config_path and os.path.exists(endpoints_config_path)):
raise FileNotFoundError(
f"Endpoints config file not found: {endpoints_config_path}"
)

try:
with open(endpoints_config_path, "r") as f:
endpoints_config = json.load(f)
except Exception as e:
raise ValueError(
f"Failed to load endpoints config file: {endpoints_config_path}"
) from e

if not (isinstance(endpoints_config, dict) and endpoint_name in endpoints_config):
raise ValueError(f"Endpoint not found in config: {endpoint_name}")

return Endpoint.from_dict(endpoints_config.get(endpoint_name))
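Note: for reference, a minimal config entry that satisfies Endpoint.from_dict could look like the sketch below. All values are illustrative placeholders, not a real deployment; extra keys in the real endpoints.json are simply ignored by from_dict.

from tests.scenario import Endpoint  # the dataclass above; adjust the import if tests/ is not a package

# Hypothetical contents of endpoints.json, mirrored as a Python dict.
example_config = {
    "re-standalone": {
        "bdb_id": 1,
        "username": "default",
        "password": "secret",
        "tls": False,
        "endpoints": ["redis://redis-12000.example.internal:12000"],
    }
}

endpoint = Endpoint.from_dict(example_config["re-standalone"])
# The url property injects the credentials into the netloc:
# redis://default:secret@redis-12000.example.internal:12000
print(endpoint.url)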
63 changes: 63 additions & 0 deletions tests/scenario/fake_app.py
@@ -0,0 +1,63 @@
import multiprocessing
import typing
from multiprocessing import Event as PEvent
from multiprocessing import Process
from threading import Event, Thread
from unittest.mock import patch

from redis import Redis


class FakeApp:

def __init__(self, client: Redis, logic: typing.Callable[[Redis], None]):
self.client = client
self.logic = logic
self.disconnects = 0

def run(self) -> typing.Tuple[Event, Thread]:
e = Event()
t = Thread(target=self._run_logic, args=(e,))
t.start()
return e, t

def _run_logic(self, e: Event):
with patch.object(
self.client, "_disconnect_raise", wraps=self.client._disconnect_raise
) as spy:
while not e.is_set():
self.logic(self.client)

self.disconnects = spy.call_count


class FakeSubscriber:

def __init__(self, client: Redis, logic: typing.Callable[[dict], None]):
self.client = client
self.logic = logic
self.disconnects = multiprocessing.Value("i", 0)

def run(self, channel: str) -> typing.Tuple[PEvent, PEvent, Process]:
e, started = PEvent(), PEvent()
p = Process(target=self._run_logic, args=(e, started, channel))
p.start()
return e, started, p

def _run_logic(self, should_stop: PEvent, started: PEvent, channel: str):
pubsub = self.client.pubsub()

with patch.object(
pubsub, "_disconnect_raise_connect", wraps=pubsub._disconnect_raise_connect
) as spy_pubsub:
pubsub.subscribe(channel)

started.set()

while not should_stop.is_set():
message = pubsub.get_message(ignore_subscribe_messages=True, timeout=1)

if message:
self.logic(message)

self.disconnects.value = spy_pubsub.call_count
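Note: roughly how these helpers are driven — a minimal sketch with a placeholder URL. The real tests obtain the client from the clients fixture and inject a fault while the loop is running.

from redis import Redis

from tests.scenario.fake_app import FakeApp  # adjust the import if tests/ is not a package

client = Redis.from_url("redis://localhost:6379/0")  # placeholder URL
app = FakeApp(client, lambda c: c.set("foo", "bar"))

stop_event, thread = app.run()
# ... inject a fault here (see fault_injection_client.py below) ...
stop_event.set()
thread.join()
print(f"observed {app.disconnects} disconnect(s)")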
39 changes: 39 additions & 0 deletions tests/scenario/fault_injection_client.py
@@ -0,0 +1,39 @@
import requests


class TriggeredAction:

def __init__(self, client: "FaultInjectionClient", data: dict):
self.client = client
self.action_id = data["action_id"]
self.data = data

def refresh(self):
self.data = self.client.get_action(self.action_id)

@property
def status(self):
if "status" not in self.data:
return "pending"
return self.data["status"]

def wait_until_complete(self):
while self.status not in ("success", "failed"):
self.refresh()
return self.status


class FaultInjectionClient:
def __init__(self, base_url: str = "http://127.0.0.1:20324"):
self.base_url = base_url

def trigger_action(self, action_type: str, parameters: dict):
response = requests.post(
f"{self.base_url}/action",
json={"type": action_type, "parameters": parameters},
)
return TriggeredAction(self, response.json())

def get_action(self, action_id: str):
response = requests.get(f"{self.base_url}/action/{action_id}")
return response.json()
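Note: a short usage sketch. The bdb_id is a placeholder; the default base_url matches the fault injection proxy used by the scenario environment.

from tests.scenario.fault_injection_client import FaultInjectionClient  # adjust the import if tests/ is not a package

fault_client = FaultInjectionClient()  # defaults to http://127.0.0.1:20324

# Ask the service to restart the DMC for a (placeholder) database id and
# block until the action reports a terminal status.
action = fault_client.trigger_action("dmc_restart", {"bdb_id": 1})
status = action.wait_until_complete()
if status == "failed":
    raise RuntimeError(f"Fault injection failed: {action.data.get('error')}")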
157 changes: 157 additions & 0 deletions tests/scenario/test_connection_interruptions.py
@@ -0,0 +1,157 @@
import multiprocessing
import time
from typing import List

import pytest
from redis import BusyLoadingError, Redis
from redis.backoff import ExponentialBackoff
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import TimeoutError as RedisTimeoutError
from redis.retry import Retry

from ..conftest import _get_client
from . import Endpoint, get_endpoint
from .fake_app import FakeApp, FakeSubscriber
from .fault_injection_client import FaultInjectionClient


@pytest.fixture
def endpoint_name():
return "re-standalone"


@pytest.fixture
def endpoint(request: pytest.FixtureRequest, endpoint_name: str):
try:
return get_endpoint(request, endpoint_name)
except FileNotFoundError as e:
pytest.skip(
f"Skipping scenario test because endpoints file is missing: {str(e)}"
)


@pytest.fixture
def clients(request: pytest.FixtureRequest, endpoint: Endpoint):
# Use the recommended retry settings
retry = Retry(ExponentialBackoff(base=1), 3)

clients = []

for _ in range(2):
r = _get_client(
Redis,
request,
decode_responses=True,
from_url=endpoint.url,
retry=retry,
retry_on_error=[
BusyLoadingError,
RedisConnectionError,
RedisTimeoutError,
# FIXME: This is a workaround for a bug in redis-py
# https://github.com/redis/redis-py/issues/3203
ConnectionError,
TimeoutError,
],
)
r.flushdb()
clients.append(r)
return clients


@pytest.fixture
def fault_injection_client(request: pytest.FixtureRequest):
return FaultInjectionClient()


@pytest.mark.parametrize("action", ("dmc_restart", "network_failure"))
def test_connection_interruptions(
clients: List[Redis],
endpoint: Endpoint,
fault_injection_client: FaultInjectionClient,
action: str,
):
client = clients.pop()
app = FakeApp(client, lambda c: c.set("foo", "bar"))

stop_app, thread = app.run()

triggered_action = fault_injection_client.trigger_action(
action, {"bdb_id": endpoint.bdb_id}
)

triggered_action.wait_until_complete()

stop_app.set()
thread.join()

if triggered_action.status == "failed":
pytest.fail(f"Action failed: {triggered_action.data['error']}")

assert app.disconnects > 0, "Client did not disconnect"


@pytest.mark.parametrize("action", ("dmc_restart", "network_failure"))
def test_pubsub_with_connection_interruptions(
clients: List[Redis],
endpoint: Endpoint,
fault_injection_client: FaultInjectionClient,
action: str,
):
channel = "test"

# The subscriber runs in a separate process so that it reacts to the
# disconnection at the same time as the publisher
with multiprocessing.Manager() as manager:
received_messages = manager.list()

def read_message(message):
nonlocal received_messages
if message and message["type"] == "message":
received_messages.append(message["data"])

subscriber_client = clients.pop()
subscriber = FakeSubscriber(subscriber_client, read_message)
stop_subscriber, subscriber_started, subscriber_t = subscriber.run(channel)

# Give the subscriber time to subscribe to the channel
subscriber_started.wait(timeout=5)

messages_sent = 0

def publish_message(c):
nonlocal messages_sent, channel
messages_sent += 1
c.publish(channel, messages_sent)

publisher_client = clients.pop()
publisher = FakeApp(publisher_client, publish_message)
stop_publisher, publisher_t = publisher.run()

triggered_action = fault_injection_client.trigger_action(
action, {"bdb_id": endpoint.bdb_id}
)

triggered_action.wait_until_complete()
messages_sent_at_trigger_completion = messages_sent

time.sleep(3) # Wait for the publisher to send more messages

stop_publisher.set()
publisher_t.join()

stop_subscriber.set()
subscriber_t.join()

assert publisher.disconnects > 0
assert subscriber.disconnects.value > 0

if triggered_action.status == "failed":
pytest.fail(f"Action failed: {triggered_action.data['error']}")

assert (
messages_sent_at_trigger_completion < messages_sent
), "No messages were sent after the failure"
assert (
int(received_messages[-1]) == messages_sent
), "Not all messages were received"