From 1146983fb0036bcc6a32be0f39f4b11203318b1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Mon, 4 Nov 2024 15:02:14 +0100 Subject: [PATCH] Create driver from QueueOptions instead of QueueConfig --- src/ert/ensemble_evaluator/_ensemble.py | 2 +- src/ert/scheduler/__init__.py | 20 +++++++++---------- src/ert/simulator/batch_simulator_context.py | 2 +- .../test_async_queue_execution.py | 2 +- .../unit_tests/scheduler/test_lsf_driver.py | 2 +- .../unit_tests/scheduler/test_scheduler.py | 4 ++-- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index 14138b1ab42..ecc1d5c81d5 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -272,7 +272,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches raise ValueError("no config") # mypy try: - driver = create_driver(self._queue_config) + driver = create_driver(self._queue_config.queue_options) self._scheduler = Scheduler( driver, self.active_reals, diff --git a/src/ert/scheduler/__init__.py b/src/ert/scheduler/__init__.py index b71c7591a1d..349e1d3838a 100644 --- a/src/ert/scheduler/__init__.py +++ b/src/ert/scheduler/__init__.py @@ -15,21 +15,21 @@ from .slurm_driver import SlurmDriver if TYPE_CHECKING: - from ert.config.queue_config import QueueConfig + from ert.config.queue_config import QueueOptions -def create_driver(config: QueueConfig) -> Driver: - if config.queue_system == QueueSystem.LOCAL: - return LocalDriver(**config.queue_options.driver_options) - elif config.queue_system == QueueSystem.TORQUE: - return OpenPBSDriver(**config.queue_options.driver_options) - elif config.queue_system == QueueSystem.LSF: - return LsfDriver(**config.queue_options.driver_options) - elif config.queue_system == QueueSystem.SLURM: +def create_driver(queue_options: QueueOptions) -> Driver: + if queue_options.name == QueueSystem.LOCAL: + return LocalDriver() + elif queue_options.name == QueueSystem.TORQUE: + return OpenPBSDriver(**queue_options.driver_options) + elif queue_options.name == QueueSystem.LSF: + return LsfDriver(**queue_options.driver_options) + elif queue_options.name == QueueSystem.SLURM: return SlurmDriver( **dict( {"user": getpwuid(getuid()).pw_name}, - **config.queue_options.driver_options, + **queue_options.driver_options, ) ) else: diff --git a/src/ert/simulator/batch_simulator_context.py b/src/ert/simulator/batch_simulator_context.py index 23147381836..41126c60561 100644 --- a/src/ert/simulator/batch_simulator_context.py +++ b/src/ert/simulator/batch_simulator_context.py @@ -145,7 +145,7 @@ def __post_init__(self) -> None: """ Handle which can be used to query status and results for batch simulation. """ - driver = create_driver(self.queue_config) + driver = create_driver(self.queue_config.queue_options) self._scheduler = Scheduler(driver, max_running=self.queue_config.max_running) # fill in the missing geo_id data diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_async_queue_execution.py b/tests/ert/unit_tests/ensemble_evaluator/test_async_queue_execution.py index b61601317f2..72845b3f753 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_async_queue_execution.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_async_queue_execution.py @@ -14,7 +14,7 @@ async def test_happy_path( ensemble = make_ensemble(monkeypatch, tmpdir, 1, 1) queue = Scheduler( - driver=create_driver(queue_config), + driver=create_driver(queue_config.queue_options), realizations=ensemble.reals, ens_id="ee_0", ) diff --git a/tests/ert/unit_tests/scheduler/test_lsf_driver.py b/tests/ert/unit_tests/scheduler/test_lsf_driver.py index 3132d2965eb..411e5649807 100644 --- a/tests/ert/unit_tests/scheduler/test_lsf_driver.py +++ b/tests/ert/unit_tests/scheduler/test_lsf_driver.py @@ -194,7 +194,7 @@ async def test_submit_with_project_code(): "FORWARD_MODEL": [("FLOW",), ("ECLIPSE",), ("RMS",)], } queue_config = QueueConfig.from_dict(queue_config_dict) - driver: LsfDriver = create_driver(queue_config) + driver: LsfDriver = create_driver(queue_config.queue_options) await driver.submit(0, "sleep") assert f"-P {queue_config.queue_options.project_code}" in Path( "captured_bsub_args" diff --git a/tests/ert/unit_tests/scheduler/test_scheduler.py b/tests/ert/unit_tests/scheduler/test_scheduler.py index 3b5de6903d8..d6d77a4c5c1 100644 --- a/tests/ert/unit_tests/scheduler/test_scheduler.py +++ b/tests/ert/unit_tests/scheduler/test_scheduler.py @@ -638,7 +638,7 @@ def test_scheduler_create_lsf_driver(): ], } queue_config = QueueConfig.from_dict(queue_config_dict) - driver = create_driver(queue_config) + driver = create_driver(queue_config.queue_options) assert isinstance(driver, LsfDriver) assert str(driver._bsub_cmd) == bsub_cmd assert str(driver._bkill_cmd) == bkill_cmd @@ -678,7 +678,7 @@ def test_scheduler_create_openpbs_driver(): ], } queue_config = QueueConfig.from_dict(queue_config_dict) - driver = create_driver(queue_config) + driver = create_driver(queue_config.queue_options) assert isinstance(driver, OpenPBSDriver) assert driver._queue_name == queue_name assert driver._keep_qsub_output == True if keep_qsub_output == "True" else False