diff --git a/osbrain/agent.py b/osbrain/agent.py index 602b9ab..16cdefe 100644 --- a/osbrain/agent.py +++ b/osbrain/agent.py @@ -36,6 +36,7 @@ from .address import AgentAddressKind from .address import AgentAddressSerializer from .address import AgentChannel +from .address import AgentChannelKind from .address import address_to_host_port from .address import guess_kind from .proxy import Proxy @@ -543,6 +544,15 @@ def register(self, socket, address, alias=None, handler=None): self.poller.register(socket, zmq.POLLIN) self._set_handler(socket, handler) + def get_handler(self, alias): + """ + Get the handler associated to a socket given the socket alias. + + Ideally, this should only be called for alias that represent a + SUB socket. + """ + return self.handler[self.socket[alias]] + def _set_handler(self, socket, handlers): """ Set the socket handler(s). @@ -647,7 +657,7 @@ def _bind_address(self, kind, alias=None, handler=None, addr=None, self.register(socket, server_address, alias, handler) # SUB sockets are a special case if kind == 'SUB': - self._subscribe(server_address, handler) + self.subscribe(server_address, handler) return server_address def _bind_channel(self, kind, alias=None, handler=None, addr=None, @@ -783,7 +793,7 @@ def _connect_address(self, server_address, alias=None, handler=None): if client_address.kind == 'SUB': if not alias: alias = client_address - self._subscribe(alias, handler) + self.subscribe(alias, handler) return client_address def _connect_channel(self, channel, alias=None, handler=None): @@ -912,7 +922,7 @@ def _handle_async_requests(self, data): else: handler(self, response) - def _subscribe(self, alias: str, handlers: Dict[Union[bytes, str], Any]): + def subscribe(self, alias: str, handlers: Dict[Union[bytes, str], Any]): """ Subscribe the agent to another agent. @@ -965,6 +975,9 @@ def set_attr(self, **kwargs): def get_attr(self, name): return getattr(self, name) + def del_attr(self, name): + delattr(self, name) + def set_method(self, *args, **kwargs): """ Set object methods. @@ -1508,6 +1521,20 @@ def close_sockets(self): for sock in self.get_unique_external_zmq_sockets(): sock.close(linger=get_linger()) + def get_uuid_used_as_alias_for_sub_in_sync_pub(self, client_alias): + """ + Return the uuid that was used as the alias for the SUB socket + when a connection to a SYNC_PUB channel was made. + """ + channel = self.addr(client_alias) + if channel.kind != AgentChannelKind('SYNC_SUB'): + raise ValueError('Incorrect channel kind: {}'.format(channel.kind)) + client_addr = channel.twin().sender.twin() + addr_to_access_uuid = self.addr(client_addr) + uuid = self._async_req_uuid[addr_to_access_uuid] + + return uuid + def ping(self): """ A test method to check the readiness of the agent. Used for testing diff --git a/osbrain/helper.py b/osbrain/helper.py index 7b0a289..2fee21a 100644 --- a/osbrain/helper.py +++ b/osbrain/helper.py @@ -192,3 +192,36 @@ def wait_agent_attr(agent, name='received', length=None, data=None, value=None, break time.sleep(0.01) return False + + +def synchronize_sync_pub(server, server_alias, client, client_alias): + ''' + Create a SYNC_PUB/SYNC_SUB channel and connect both agents. + + Make sure they have stablished the PUB/SUB communication within the + SYNC_PUB/SYNC_SUB channel before returning. This will guarantee that + no PUB messages are lost. + ''' + def assert_receive(agent, message, topic=None): + try: + agent.get_attr('_tmp_attr') + agent.set_attr(_tmp_attr=True) + except AttributeError: # Attribute already deleted + pass + + uuid = client.get_uuid_used_as_alias_for_sub_in_sync_pub(client_alias) + + # Set a temporary custom handler + client.set_attr(_tmp_attr=False) + original_handler = client.get_handler(uuid) + client.subscribe(uuid, assert_receive) + + # Send messages through the PUB socket until the client receives them + server.each(0.1, 'send', server_alias, 'Synchronize', alias='_tmp_timer') + assert wait_agent_attr(client, name='_tmp_attr', value=True, timeout=5) + server.stop_timer('_tmp_timer') + + # Restore the original handler, now that the connection is guaranteed + client.subscribe(uuid, original_handler) + + client.del_attr('_tmp_attr') diff --git a/osbrain/tests/test_agent.py b/osbrain/tests/test_agent.py index 17ab67e..f69ddd7 100644 --- a/osbrain/tests/test_agent.py +++ b/osbrain/tests/test_agent.py @@ -421,6 +421,60 @@ def test_invalid_handlers(nsproxy): agent.bind('REP', handler=1.234) +def test_get_handler(nsproxy): + """ + Make sure the actual handler is returned. + """ + server = run_agent('server') + client = run_agent('client') + + pub_addr = server.bind('PUB', alias='pub') + client.connect(pub_addr, alias='sub', handler=receive) + + assert client.get_handler('sub') + with pytest.raises(KeyError): + server.get_handler('pub') + + +def test_get_uuid_used_as_alias_for_sub_in_sync_pub_sync(nsproxy): + """ + The function should only work for SYNC_SUB channels, and should raise + an exception in any other case. + """ + server = run_agent('server') + client = run_agent('client') + + sync_pub_addr = server.bind('SYNC_PUB', alias='sync_pub', handler=receive) + client.connect(sync_pub_addr, alias='sync_sub', handler=receive) + + # Should work for SYNC_SUB channels + assert client.get_uuid_used_as_alias_for_sub_in_sync_pub('sync_sub') + + # Should not work for other channels + with pytest.raises(ValueError): + server.get_uuid_used_as_alias_for_sub_in_sync_pub('sync_pub') + + +def test_get_uuid_used_as_alias_for_sub_in_sync_pub_async(nsproxy): + """ + The function should only work for SYNC_SUB channels, and should raise + an exception in any other case. + """ + server = run_agent('server') + client = run_agent('client') + + async_rep_addr = server.bind('ASYNC_REP', alias='async_rep', + handler=receive) + client.connect(async_rep_addr, alias='async_req', handler=receive) + + # Should not work for either channel + with pytest.raises(ValueError): + client.get_uuid_used_as_alias_for_sub_in_sync_pub('async_req') + + with pytest.raises(ValueError): + server.get_uuid_used_as_alias_for_sub_in_sync_pub('async_rep') + + def test_log_levels(nsproxy): """ Test different log levels: info, warning, error and debug. Debug messages diff --git a/osbrain/tests/test_agent_pubsub_topics.py b/osbrain/tests/test_agent_pubsub_topics.py index c122c5f..5273e2c 100644 --- a/osbrain/tests/test_agent_pubsub_topics.py +++ b/osbrain/tests/test_agent_pubsub_topics.py @@ -4,6 +4,7 @@ from osbrain import run_agent from osbrain.address import AgentAddressSerializer +from osbrain.helper import wait_agent_attr from common import nsproxy # pragma: no flakes from common import append_received @@ -164,3 +165,41 @@ def test_pubsub_topics_raw(nsproxy, serializer): assert b'fooWorld' in a5.get_attr('received') assert b'foobarFOO' in a5.get_attr('received') assert b'foBAR' in a5.get_attr('received') + + +def test_subscribe(nsproxy): + """ + Test the `subscribe` function works as expected for SUB sockets. + """ + def receive_square(agent, message, topic=None): + agent.received.append(message**2) + + def receive_cube(agent, message, topic=None): + agent.received.append(message**3) + + server = run_agent('server') + client = run_agent('client') + + addr = server.bind('PUB', alias='pub') + client.set_attr(received=[]) + client.connect(addr, alias='sub', handler=receive) + + # Give some time for the client to connect + time.sleep(0.1) + + server.send('pub', 2) + assert wait_agent_attr(client, data=2) + + client.subscribe('sub', handlers={'foo': receive_square, + 'bar': receive_cube}) + + server.send('pub', 2, topic='foo') + server.send('pub', 2, topic='bar') + server.send('pub', 3) + + # Check new handlers were used for different topics + assert wait_agent_attr(client, data=4) + assert wait_agent_attr(client, data=8) + + # No longer subscribed to all topics + assert not wait_agent_attr(client, data=3) diff --git a/osbrain/tests/test_helper.py b/osbrain/tests/test_helper.py index 6565672..39328a8 100644 --- a/osbrain/tests/test_helper.py +++ b/osbrain/tests/test_helper.py @@ -9,9 +9,41 @@ from osbrain.helper import agent_dies from osbrain.helper import attribute_match_all from osbrain.helper import wait_agent_attr +from osbrain.helper import synchronize_sync_pub from common import nsproxy # pragma: no flakes from common import agent_logger # pragma: no flakes +from common import append_received + + +def test_synchronize_sync_pub(nsproxy): + """ + Publications sent through SYNC_PUB/SYNC_SUB after synchronazing them + should be received without exception. + """ + server = run_agent('server') + client = run_agent('client') + + addr = server.bind('SYNC_PUB', alias='sync_pub', handler=append_received) + + client.set_attr(received=[]) + client.connect(addr, alias='sync_sub', handler=append_received) + + # Guarantee the PUB/SUB is stablished + synchronize_sync_pub(server, 'sync_pub', client, 'sync_sub') + + # Send the message only once + server.send('sync_pub', 'Hello') + + assert wait_agent_attr(client, name='received', data='Hello') + + # Check that no temporary attributes remain + assert 'Synchronize' not in client.get_attr('received') + + with pytest.raises(AttributeError): + client.get_attr('_tmp_attr') + + assert '_tmp_timer' not in server.list_timers() def test_agent_dies(nsproxy):