Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix reconnection with unit test coverage #109

Merged
merged 4 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions neon_mq_connector/consumers/select_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import threading
import time

from asyncio import Event, run
from asyncio import Event
from typing import Optional

import pika.exceptions
Expand Down Expand Up @@ -185,10 +185,14 @@ def on_message(self, channel, method, properties, body):
self.error_func(self, e)

def on_close(self, _, e):
self._consumer_started.clear()
if isinstance(e, pika.exceptions.ConnectionClosed):
LOG.info(f"Connection closed normally: {e}")
if not self._stopping:
else:
LOG.error(f"Closing MQ connection due to exception: {e}")
if not self._stopping:
# Connection was gracefully closed by the server. Try to re-connect
LOG.info(f"Trying to reconnect after server closed connection")
self.reconnect()

@property
Expand All @@ -200,10 +204,9 @@ def is_consuming(self) -> bool:
return self._consumer_started.is_set()

def run(self):
"""Starting connnection io loop """
"""Starting connection io loop """
if not self.is_consuming:
try:
super(SelectConsumerThread, self).run()
self.connection: pika.SelectConnection = self.create_connection()
self.connection.ioloop.start()
except (pika.exceptions.ChannelClosed,
Expand All @@ -217,6 +220,8 @@ def run(self):
LOG.error(f"Failed to start io loop on consumer thread {self.name!r}: {e}")
self._close_connection()
self.error_func(self, e)
else:
LOG.warning("Consumer already running!")

def _close_connection(self, mark_consumer_as_dead: bool = True):
try:
Expand All @@ -228,9 +233,16 @@ def _close_connection(self, mark_consumer_as_dead: bool = True):
raise TimeoutError(f"Timeout waiting for channel close. "
f"is_closed={self.channel.is_closed}")
LOG.info(f"Channel closed")

# Wait for the connection to close
waiter = threading.Event()
while not self.connection.is_closed:
waiter.wait(1)
LOG.info(f"Connection closed")

if self.connection:
self.connection.ioloop.stop()
self.connection = None
# self.connection = None
except Exception as e:
LOG.error(f"Failed to close connection for Consumer {self.name!r}: {e}")
self._is_consuming = False
Expand All @@ -240,8 +252,10 @@ def _close_connection(self, mark_consumer_as_dead: bool = True):
else:
self._stopping = False

def reconnect(self, wait_interval: int = 1):
def reconnect(self, wait_interval: int = 5):
self._close_connection(mark_consumer_as_dead=False)
# TODO: Find a better way to wait for shutdown/server restart. This will
# fail to reconnect if the server isn't back up within `wait_interval`
time.sleep(wait_interval)
self.run()

Expand Down
50 changes: 46 additions & 4 deletions tests/test_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import asyncio
from time import sleep
from unittest.mock import Mock

import pytest

from time import sleep
from unittest.mock import Mock
from unittest import TestCase

from pika.connection import ConnectionParameters
from pika.credentials import PlainCredentials
from pika.exchange_type import ExchangeType
Expand Down Expand Up @@ -171,3 +169,47 @@ def test_select_consumer_thread(self):
self.assertFalse(test_thread.is_consuming)
self.assertFalse(test_thread.is_consumer_alive)
test_thread.on_close.assert_not_called()

def test_handle_reconnection(self):
from neon_mq_connector.consumers.select_consumer import SelectConsumerThread
connection_params = ConnectionParameters(host='localhost',
port=self.rmq_instance.port,
virtual_host="/neon_testing",
credentials=PlainCredentials(
"test_user",
"test_password"))
queue = "test_q"
callback = Mock()
error = Mock()

# Valid thread
test_thread = SelectConsumerThread(connection_params, queue, callback,
error)
test_thread.on_connected = Mock(side_effect=test_thread.on_connected)
test_thread.on_channel_open = Mock(side_effect=test_thread.on_channel_open)
test_thread.on_close = Mock(side_effect=test_thread.on_close)

test_thread.start()
while not test_thread.is_consuming:
sleep(0.1)

test_thread.on_connected.assert_called_once()
test_thread.on_channel_open.assert_called_once()
test_thread.on_close.assert_not_called()

self.rmq_instance.stop()
sleep(1) # Wait for the client to finish disconnecting
test_thread.on_close.assert_called_once()
self.assertFalse(test_thread.is_consuming)
self.assertTrue(test_thread.is_consumer_alive)

self.rmq_instance.start()
# TODO: Wait for re-connection
while not test_thread.is_consuming:
sleep(0.1)
self.assertTrue(test_thread.is_consuming)
self.assertTrue(test_thread.is_consumer_alive)

test_thread.join(30)
self.assertFalse(test_thread.is_consuming)
self.assertFalse(test_thread.is_consumer_alive)
Loading