Skip to content

Commit

Permalink
Move redis subscribe to poke() method in Redis Sensor (apache#32984)
Browse files Browse the repository at this point in the history
In RedisPubSubSensor subscription has been done in constructor,
which was pretty wrong - for example it means that when scheduler
parses the sensor, it involves subscribing to the messages and
commmunication with redis DB.

This PR moves subscription to "poke()" method, which is executed
on worker instead.
  • Loading branch information
potiuk authored Aug 1, 2023
1 parent 05494e5 commit 17a3dd4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions airflow/providers/redis/sensors/redis_pub_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from airflow.providers.redis.hooks.redis import RedisHook
Expand All @@ -41,8 +42,10 @@ def __init__(self, *, channels: list[str] | str, redis_conn_id: str, **kwargs) -
super().__init__(**kwargs)
self.channels = channels
self.redis_conn_id = redis_conn_id
self.pubsub = RedisHook(redis_conn_id=self.redis_conn_id).get_conn().pubsub()
self.pubsub.subscribe(self.channels)

@cached_property
def pubsub(self):
return RedisHook(redis_conn_id=self.redis_conn_id).get_conn().pubsub()

def poke(self, context: Context) -> bool:
"""
Expand All @@ -54,7 +57,7 @@ def poke(self, context: Context) -> bool:
:return: ``True`` if message (with type 'message') is available or ``False`` if not
"""
self.log.info("RedisPubSubSensor checking for message on channels: %s", self.channels)

self.pubsub.subscribe(self.channels)
message = self.pubsub.get_message()
self.log.info("Message %s from channel %s", message, self.channels)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def test_poke_true(self):

hook = RedisHook(redis_conn_id="redis_default")
redis = hook.get_conn()
redis.publish("test", "message")

result = sensor.poke(self.mock_context)
assert not result
redis.publish("test", "message")

for _ in range(1, 10):
result = sensor.poke(self.mock_context)
Expand Down

0 comments on commit 17a3dd4

Please sign in to comment.