From c39700bc2283500e0327ed3d7f00cf0056f2dfed Mon Sep 17 00:00:00 2001 From: Yadu Nand Babuji Date: Wed, 19 Jul 2023 15:18:16 -0500 Subject: [PATCH] Update HTEX.Interchange to listen on a single interface (#2828) This PR updates the HTEX.Interchange to listen for connections from the managers only on a specific interface rather than the current default of binding to all interfaces. Binding to all interfaces is generally frowned upon on login nodes, when all we need is to allow connections from the internal network. Here are the changes: Pass HighThroughputExecutor.address to Interchange.interchange_address Interchange will bind only to the interchange_address if it is specified, instead of binding to zmq:* or 0.0.0.0, which is the current default. Adding tests for the Interchange Please note that configs which specify HTEX(address="localhost"), or which otherwise supply something other than an IPv4 address, will now fail --- parsl/executors/high_throughput/executor.py | 8 ++- .../executors/high_throughput/interchange.py | 72 ++++--------------- parsl/tests/test_htex/__init__.py | 0 .../tests/test_htex/test_htex_zmq_binding.py | 46 ++++++++++++ 4 files changed, 66 insertions(+), 60 deletions(-) create mode 100644 parsl/tests/test_htex/__init__.py create mode 100644 parsl/tests/test_htex/test_htex_zmq_binding.py diff --git a/parsl/executors/high_throughput/executor.py b/parsl/executors/high_throughput/executor.py index 66702c68fa..fbb4004504 100644 --- a/parsl/executors/high_throughput/executor.py +++ b/parsl/executors/high_throughput/executor.py @@ -97,9 +97,10 @@ class HighThroughputExecutor(BlockProviderExecutor, RepresentationMixin): address : string An address to connect to the main Parsl process which is reachable from the network in which - workers will be running. This can be either a hostname as returned by ``hostname`` or an - IP address. Most login nodes on clusters have several network interfaces available, only - some of which can be reached from the compute nodes. 
+ workers will be running. This field expects an IPv4 address (xxx.xxx.xxx.xxx). + Most login nodes on clusters have several network interfaces available, only some of which + can be reached from the compute nodes. This field can be used to limit the executor to listen + only on a specific interface, and limiting connections to the internal network. By default, the executor will attempt to enumerate and connect through all possible addresses. Setting an address here overrides the default behavior. default=None @@ -470,6 +471,7 @@ def _start_local_interchange_process(self): kwargs={"client_ports": (self.outgoing_q.port, self.incoming_q.port, self.command_client.port), + "interchange_address": self.address, "worker_ports": self.worker_ports, "worker_port_range": self.worker_port_range, "hub_address": self.hub_address, diff --git a/parsl/executors/high_throughput/interchange.py b/parsl/executors/high_throughput/interchange.py index d522897d21..d032144655 100644 --- a/parsl/executors/high_throughput/interchange.py +++ b/parsl/executors/high_throughput/interchange.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -import argparse import zmq import os import sys @@ -14,7 +13,7 @@ import threading import json -from typing import cast, Any, Dict, Set +from typing import cast, Any, Dict, Set, Optional from parsl.utils import setproctitle from parsl.version import VERSION as PARSL_VERSION @@ -29,6 +28,9 @@ HEARTBEAT_CODE = (2 ** 32) - 1 PKL_HEARTBEAT_CODE = pickle.dumps((2 ** 32) - 1) +LOGGER_NAME = "interchange" +logger = logging.getLogger(LOGGER_NAME) + class ManagerLost(Exception): ''' Task lost due to manager loss. 
Manager is considered lost when multiple heartbeats @@ -66,7 +68,7 @@ class Interchange: """ def __init__(self, client_address="127.0.0.1", - interchange_address="127.0.0.1", + interchange_address: Optional[str] = None, client_ports=(50055, 50056, 50057), worker_ports=None, worker_port_range=(54000, 55000), @@ -83,8 +85,9 @@ def __init__(self, client_address : str The ip address at which the parsl client can be reached. Default: "127.0.0.1" - interchange_address : str - The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1" + interchange_address : Optional str + If specified the interchange will only listen on this address for connections from workers + else, it binds to all addresses. client_ports : triple(int, int, int) The ports at which the client can be reached @@ -125,7 +128,7 @@ def __init__(self, logger.debug("Initializing Interchange process") self.client_address = client_address - self.interchange_address = interchange_address + self.interchange_address: str = interchange_address or "*" self.poll_period = poll_period logger.info("Attempting connection to client at {} on ports: {},{},{}".format( @@ -160,14 +163,14 @@ def __init__(self, self.worker_task_port = self.worker_ports[0] self.worker_result_port = self.worker_ports[1] - self.task_outgoing.bind("tcp://*:{}".format(self.worker_task_port)) - self.results_incoming.bind("tcp://*:{}".format(self.worker_result_port)) + self.task_outgoing.bind(f"tcp://{self.interchange_address}:{self.worker_task_port}") + self.results_incoming.bind(f"tcp://{self.interchange_address}:{self.worker_result_port}") else: - self.worker_task_port = self.task_outgoing.bind_to_random_port('tcp://*', + self.worker_task_port = self.task_outgoing.bind_to_random_port(f"tcp://{self.interchange_address}", min_port=worker_port_range[0], max_port=worker_port_range[1], max_tries=100) - self.worker_result_port = self.results_incoming.bind_to_random_port('tcp://*', + self.worker_result_port = 
self.results_incoming.bind_to_random_port(f"tcp://{self.interchange_address}", min_port=worker_port_range[0], max_port=worker_port_range[1], max_tries=100) @@ -574,7 +577,7 @@ def expire_bad_managers(self, interesting_managers, hub_channel): interesting_managers.remove(manager_id) -def start_file_logger(filename, name='interchange', level=logging.DEBUG, format_string=None): +def start_file_logger(filename, level=logging.DEBUG, format_string=None): """Add a stream log handler. Parameters @@ -582,8 +585,6 @@ def start_file_logger(filename, name='interchange', level=logging.DEBUG, format_ filename: string Name of the file to write logs to. Required. - name: string - Logger name. Default="parsl.executors.interchange" level: logging.LEVEL Set the logging level. Default=logging.DEBUG - format_string (string): Set the format string @@ -598,7 +599,7 @@ def start_file_logger(filename, name='interchange', level=logging.DEBUG, format_ format_string = "%(asctime)s.%(msecs)03d %(name)s:%(lineno)d %(processName)s(%(process)d) %(threadName)s %(funcName)s [%(levelname)s] %(message)s" global logger - logger = logging.getLogger(name) + logger = logging.getLogger(LOGGER_NAME) logger.setLevel(level) handler = logging.FileHandler(filename) handler.setLevel(level) @@ -619,46 +620,3 @@ def starter(comm_q, *args, **kwargs): comm_q.put((ic.worker_task_port, ic.worker_result_port)) ic.start() - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument("-c", "--client_address", - help="Client address") - parser.add_argument("-l", "--logdir", default="parsl_worker_logs", - help="Parsl worker log directory") - parser.add_argument("-t", "--task_url", - help="REQUIRED: ZMQ url for receiving tasks") - parser.add_argument("-r", "--result_url", - help="REQUIRED: ZMQ url for posting results") - parser.add_argument("-p", "--poll_period", - help="REQUIRED: poll period used for main thread") - parser.add_argument("--worker_ports", default=None, - help="OPTIONAL, pair of 
workers ports to listen on, eg --worker_ports=50001,50005") - parser.add_argument("-d", "--debug", action='store_true', - help="Count of apps to launch") - - args = parser.parse_args() - - # Setup logging - global logger - format_string = "%(asctime)s %(name)s:%(lineno)d [%(levelname)s] %(message)s" - - logger = logging.getLogger("interchange") - logger.setLevel(logging.DEBUG) - handler = logging.StreamHandler() - handler.setLevel('DEBUG' if args.debug is True else 'INFO') - formatter = logging.Formatter(format_string, datefmt='%Y-%m-%d %H:%M:%S') - handler.setFormatter(formatter) - logger.addHandler(handler) - - logger.debug("Starting Interchange") - - optionals = {} - - if args.worker_ports: - optionals['worker_ports'] = [int(i) for i in args.worker_ports.split(',')] - - ic = Interchange(**optionals) - ic.start() diff --git a/parsl/tests/test_htex/__init__.py b/parsl/tests/test_htex/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/parsl/tests/test_htex/test_htex_zmq_binding.py b/parsl/tests/test_htex/test_htex_zmq_binding.py new file mode 100644 index 0000000000..442f0431b6 --- /dev/null +++ b/parsl/tests/test_htex/test_htex_zmq_binding.py @@ -0,0 +1,46 @@ +import logging + +import psutil +import pytest +import zmq + +from parsl.executors.high_throughput.interchange import Interchange + + +def test_interchange_binding_no_address(): + ix = Interchange() + assert ix.interchange_address == "*" + + +def test_interchange_binding_with_address(): + # Using loopback address + address = "127.0.0.1" + ix = Interchange(interchange_address=address) + assert ix.interchange_address == address + + +def test_interchange_binding_with_non_ipv4_address(): + # Confirm that a ipv4 address is required + address = "localhost" + with pytest.raises(zmq.error.ZMQError): + Interchange(interchange_address=address) + + +def test_interchange_binding_bad_address(): + """ Confirm that we raise a ZMQError when a bad address is supplied""" + address = "550.0.0.0" + with 
pytest.raises(zmq.error.ZMQError): + Interchange(interchange_address=address) + + +def test_limited_interface_binding(): + """ When address is specified the worker_port would be bound to it rather than to 0.0.0.0""" + address = "127.0.0.1" + ix = Interchange(interchange_address=address) + ix.worker_result_port + proc = psutil.Process() + conns = proc.connections(kind="tcp") + + matched_conns = [conn for conn in conns if conn.laddr.port == ix.worker_result_port] + assert len(matched_conns) == 1 + assert matched_conns[0].laddr.ip == address