Skip to content

Commit

Permalink
Use SentinelConnectionPoolProxy in asyncio module.
Browse files Browse the repository at this point in the history
  • Loading branch information
DABND19 committed Jul 18, 2024
1 parent 9607608 commit a745bde
Showing 1 changed file with 87 additions and 0 deletions.
87 changes: 87 additions & 0 deletions redis/asyncio/sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,55 @@ class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection):
pass


class SentinelConnectionPoolProxy:
def __init__(
self,
connection_pool,
is_master,
check_connection,
service_name,
sentinel_manager,
):
self.connection_pool_ref = weakref.ref(connection_pool)
self.is_master = is_master
self.check_connection = check_connection
self.service_name = service_name
self.sentinel_manager = sentinel_manager
self.reset()

def reset(self):
self.master_address = None
self.slave_rr_counter = None

async def get_master_address(self):
master_address = await self.sentinel_manager.discover_master(self.service_name)
if self.is_master and self.master_address != master_address:
self.master_address = master_address
# disconnect any idle connections so that they reconnect
# to the new master the next time that they are used.
connection_pool = self.connection_pool_ref()
if connection_pool is not None:
await connection_pool.disconnect(inuse_connections=False)
return master_address

async def rotate_slaves(self) -> AsyncIterator:
"""Round-robin slave balancer"""
slaves = await self.sentinel_manager.discover_slaves(self.service_name)
if slaves:
if self.slave_rr_counter is None:
self.slave_rr_counter = random.randint(0, len(slaves) - 1)
for _ in range(len(slaves)):
self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves)
slave = slaves[self.slave_rr_counter]
yield slave
# Fallback to the master connection
try:
yield await self.get_master_address()
except MasterNotFoundError:
pass
raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")


class SentinelConnectionPool(ConnectionPool):
"""
Sentinel backed connection pool.
Expand All @@ -116,6 +165,44 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
)
self.is_master = kwargs.pop("is_master", True)
self.check_connection = kwargs.pop("check_connection", False)
self.proxy = SentinelConnectionPoolProxy(
connection_pool=self,
is_master=self.is_master,
check_connection=self.check_connection,
service_name=service_name,
sentinel_manager=sentinel_manager,
)
super().__init__(**kwargs)
self.connection_kwargs["connection_pool"] = weakref.proxy(self)
self.service_name = service_name
self.sentinel_manager = sentinel_manager

def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
f"(service={self.service_name}({self.is_master and 'master' or 'slave'}))>"
)

def reset(self):
super().reset()
self.proxy.reset()

@property
def master_address(self):
return self.proxy.master_address

def owns_connection(self, connection: Connection):
check = not self.is_master or (
self.is_master and self.master_address == (connection.host, connection.port)
)
return check and super().owns_connection(connection)

async def get_master_address(self):
return await self.proxy.get_master_address()

def rotate_slaves(self) -> AsyncIterator:
"""Round-robin slave balancer"""
return self.proxy.rotate_slaves()
super().__init__(**kwargs)
self.connection_kwargs["connection_pool"] = weakref.proxy(self)
self.service_name = service_name
Expand Down

0 comments on commit a745bde

Please sign in to comment.