Skip to content

Commit

Permalink
Create driver from QueueOptions instead of QueueConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Nov 5, 2024
1 parent aba7e5c commit 1146983
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
raise ValueError("no config") # mypy

try:
driver = create_driver(self._queue_config)
driver = create_driver(self._queue_config.queue_options)
self._scheduler = Scheduler(
driver,
self.active_reals,
Expand Down
20 changes: 10 additions & 10 deletions src/ert/scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,21 @@
from .slurm_driver import SlurmDriver

if TYPE_CHECKING:
from ert.config.queue_config import QueueConfig
from ert.config.queue_config import QueueOptions


def create_driver(config: QueueConfig) -> Driver:
if config.queue_system == QueueSystem.LOCAL:
return LocalDriver(**config.queue_options.driver_options)
elif config.queue_system == QueueSystem.TORQUE:
return OpenPBSDriver(**config.queue_options.driver_options)
elif config.queue_system == QueueSystem.LSF:
return LsfDriver(**config.queue_options.driver_options)
elif config.queue_system == QueueSystem.SLURM:
def create_driver(queue_options: QueueOptions) -> Driver:
if queue_options.name == QueueSystem.LOCAL:
return LocalDriver()
elif queue_options.name == QueueSystem.TORQUE:
return OpenPBSDriver(**queue_options.driver_options)
elif queue_options.name == QueueSystem.LSF:
return LsfDriver(**queue_options.driver_options)
elif queue_options.name == QueueSystem.SLURM:
return SlurmDriver(
**dict(
{"user": getpwuid(getuid()).pw_name},
**config.queue_options.driver_options,
**queue_options.driver_options,
)
)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/ert/simulator/batch_simulator_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __post_init__(self) -> None:
"""
Handle which can be used to query status and results for batch simulation.
"""
driver = create_driver(self.queue_config)
driver = create_driver(self.queue_config.queue_options)
self._scheduler = Scheduler(driver, max_running=self.queue_config.max_running)

# fill in the missing geo_id data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def test_happy_path(
ensemble = make_ensemble(monkeypatch, tmpdir, 1, 1)

queue = Scheduler(
driver=create_driver(queue_config),
driver=create_driver(queue_config.queue_options),
realizations=ensemble.reals,
ens_id="ee_0",
)
Expand Down
2 changes: 1 addition & 1 deletion tests/ert/unit_tests/scheduler/test_lsf_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ async def test_submit_with_project_code():
"FORWARD_MODEL": [("FLOW",), ("ECLIPSE",), ("RMS",)],
}
queue_config = QueueConfig.from_dict(queue_config_dict)
driver: LsfDriver = create_driver(queue_config)
driver: LsfDriver = create_driver(queue_config.queue_options)
await driver.submit(0, "sleep")
assert f"-P {queue_config.queue_options.project_code}" in Path(
"captured_bsub_args"
Expand Down
4 changes: 2 additions & 2 deletions tests/ert/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def test_scheduler_create_lsf_driver():
],
}
queue_config = QueueConfig.from_dict(queue_config_dict)
driver = create_driver(queue_config)
driver = create_driver(queue_config.queue_options)
assert isinstance(driver, LsfDriver)
assert str(driver._bsub_cmd) == bsub_cmd
assert str(driver._bkill_cmd) == bkill_cmd
Expand Down Expand Up @@ -678,7 +678,7 @@ def test_scheduler_create_openpbs_driver():
],
}
queue_config = QueueConfig.from_dict(queue_config_dict)
driver = create_driver(queue_config)
driver = create_driver(queue_config.queue_options)
assert isinstance(driver, OpenPBSDriver)
assert driver._queue_name == queue_name
assert driver._keep_qsub_output == True if keep_qsub_output == "True" else False
Expand Down

0 comments on commit 1146983

Please sign in to comment.