Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make QueueConfig serializable #9143

Merged
merged 8 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import re
import shutil
from abc import abstractmethod
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, List, Mapping, Optional, no_type_check
from dataclasses import asdict, field, fields
from typing import Any, Dict, List, Literal, Mapping, Optional, Union, no_type_check

import pydantic
from pydantic import Field
from pydantic.dataclasses import dataclass
from typing_extensions import Annotated

from .parsing import (
Expand All @@ -27,6 +29,7 @@

@pydantic.dataclasses.dataclass(config={"extra": "forbid", "validate_assignment": True})
class QueueOptions:
name: str
max_running: pydantic.NonNegativeInt = 0
submit_sleep: pydantic.NonNegativeFloat = 0.0
project_code: Optional[str] = None
Expand Down Expand Up @@ -79,13 +82,16 @@ def driver_options(self) -> Dict[str, Any]:

@pydantic.dataclasses.dataclass
class LocalQueueOptions(QueueOptions):
name: Literal[QueueSystem.LOCAL] = QueueSystem.LOCAL

@property
def driver_options(self) -> Dict[str, Any]:
return {}


@pydantic.dataclasses.dataclass
class LsfQueueOptions(QueueOptions):
name: Literal[QueueSystem.LSF] = QueueSystem.LSF
bhist_cmd: Optional[NonEmptyString] = None
bjobs_cmd: Optional[NonEmptyString] = None
bkill_cmd: Optional[NonEmptyString] = None
Expand All @@ -97,6 +103,7 @@ class LsfQueueOptions(QueueOptions):
@property
def driver_options(self) -> Dict[str, Any]:
driver_dict = asdict(self)
driver_dict.pop("name")
driver_dict["exclude_hosts"] = driver_dict.pop("exclude_host")
driver_dict["queue_name"] = driver_dict.pop("lsf_queue")
driver_dict["resource_requirement"] = driver_dict.pop("lsf_resource")
Expand All @@ -107,6 +114,7 @@ def driver_options(self) -> Dict[str, Any]:

@pydantic.dataclasses.dataclass
class TorqueQueueOptions(QueueOptions):
name: Literal[QueueSystem.TORQUE] = QueueSystem.TORQUE
qsub_cmd: Optional[NonEmptyString] = None
qstat_cmd: Optional[NonEmptyString] = None
qdel_cmd: Optional[NonEmptyString] = None
Expand All @@ -124,6 +132,7 @@ class TorqueQueueOptions(QueueOptions):
@property
def driver_options(self) -> Dict[str, Any]:
driver_dict = asdict(self)
driver_dict.pop("name")
driver_dict["queue_name"] = driver_dict.pop("queue")
driver_dict.pop("max_running")
driver_dict.pop("submit_sleep")
Expand All @@ -133,21 +142,22 @@ def driver_options(self) -> Dict[str, Any]:

@pydantic.field_validator("memory_per_job")
@classmethod
def check_memory_per_job(cls, value: str) -> str:
def check_memory_per_job(cls, value: Optional[str]) -> Optional[str]:
if not queue_memory_usage_formats[QueueSystem.TORQUE].validate(value):
raise ValueError("wrong memory format")
return value


@pydantic.dataclasses.dataclass
class SlurmQueueOptions(QueueOptions):
name: Literal[QueueSystem.SLURM] = QueueSystem.SLURM
sbatch: NonEmptyString = "sbatch"
scancel: NonEmptyString = "scancel"
scontrol: NonEmptyString = "scontrol"
squeue: NonEmptyString = "squeue"
exclude_host: str = ""
include_host: str = ""
memory: str = ""
memory: Optional[NonEmptyString] = None
memory_per_cpu: Optional[NonEmptyString] = None
partition: Optional[NonEmptyString] = None # aka queue_name
squeue_timeout: pydantic.PositiveFloat = 2
Expand All @@ -156,6 +166,7 @@ class SlurmQueueOptions(QueueOptions):
@property
def driver_options(self) -> Dict[str, Any]:
driver_dict = asdict(self)
driver_dict.pop("name")
driver_dict["sbatch_cmd"] = driver_dict.pop("sbatch")
driver_dict["scancel_cmd"] = driver_dict.pop("scancel")
driver_dict["scontrol_cmd"] = driver_dict.pop("scontrol")
Expand All @@ -169,7 +180,7 @@ def driver_options(self) -> Dict[str, Any]:

@pydantic.field_validator("memory", "memory_per_cpu")
@classmethod
def check_memory_per_job(cls, value: str) -> str:
def check_memory_per_job(cls, value: Optional[str]) -> Optional[str]:
if not queue_memory_usage_formats[QueueSystem.SLURM].validate(value):
raise ValueError("wrong memory format")
return value
Expand All @@ -179,7 +190,9 @@ def check_memory_per_job(cls, value: str) -> str:
class QueueMemoryStringFormat:
suffixes: List[str]

def validate(self, mem_str_format: str) -> bool:
def validate(self, mem_str_format: Optional[str]) -> bool:
if mem_str_format is None:
return True
return (
re.match(
r"\d+(" + "|".join(self.suffixes) + ")$",
Expand Down Expand Up @@ -255,8 +268,10 @@ class QueueConfig:
realization_memory: int = 0
max_submit: int = 1
queue_system: QueueSystem = QueueSystem.LOCAL
queue_options: QueueOptions = field(default_factory=QueueOptions)
queue_options_test_run: QueueOptions = field(default_factory=LocalQueueOptions)
queue_options: Union[
LsfQueueOptions, TorqueQueueOptions, SlurmQueueOptions, LocalQueueOptions
] = Field(default_factory=LocalQueueOptions, discriminator="name")
queue_options_test_run: LocalQueueOptions = field(default_factory=LocalQueueOptions)
stop_long_running: bool = False

@no_type_check
Expand Down Expand Up @@ -349,7 +364,7 @@ def from_dict(cls, config_dict: ConfigDict) -> QueueConfig:
selected_queue_system,
queue_options,
queue_options_test_run,
stop_long_running=stop_long_running,
stop_long_running=bool(stop_long_running),
)

def create_local_copy(self) -> QueueConfig:
Expand All @@ -360,7 +375,7 @@ def create_local_copy(self) -> QueueConfig:
QueueSystem.LOCAL,
self.queue_options_test_run,
self.queue_options_test_run,
stop_long_running=self.stop_long_running,
stop_long_running=bool(self.stop_long_running),
)

@property
Expand Down
2 changes: 1 addition & 1 deletion src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
raise ValueError("no config") # mypy

try:
driver = create_driver(self._queue_config)
driver = create_driver(self._queue_config.queue_options)
self._scheduler = Scheduler(
driver,
self.active_reals,
Expand Down
20 changes: 10 additions & 10 deletions src/ert/scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,21 @@
from .slurm_driver import SlurmDriver

if TYPE_CHECKING:
from ert.config.queue_config import QueueConfig
from ert.config.queue_config import QueueOptions


def create_driver(config: QueueConfig) -> Driver:
if config.queue_system == QueueSystem.LOCAL:
return LocalDriver(**config.queue_options.driver_options)
elif config.queue_system == QueueSystem.TORQUE:
return OpenPBSDriver(**config.queue_options.driver_options)
elif config.queue_system == QueueSystem.LSF:
return LsfDriver(**config.queue_options.driver_options)
elif config.queue_system == QueueSystem.SLURM:
def create_driver(queue_options: QueueOptions) -> Driver:
if queue_options.name == QueueSystem.LOCAL:
return LocalDriver()
elif queue_options.name == QueueSystem.TORQUE:
return OpenPBSDriver(**queue_options.driver_options)
elif queue_options.name == QueueSystem.LSF:
return LsfDriver(**queue_options.driver_options)
elif queue_options.name == QueueSystem.SLURM:
return SlurmDriver(
**dict(
{"user": getpwuid(getuid()).pw_name},
**config.queue_options.driver_options,
**queue_options.driver_options,
)
)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/ert/simulator/batch_simulator_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __post_init__(self) -> None:
"""
Handle which can be used to query status and results for batch simulation.
"""
driver = create_driver(self.queue_config)
driver = create_driver(self.queue_config.queue_options)
self._scheduler = Scheduler(driver, max_running=self.queue_config.max_running)

# fill in the missing geo_id data
Expand Down
21 changes: 8 additions & 13 deletions src/everest/bin/everest_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,11 @@ def _build_args_parser():
return arg_parser


def _run_everest(options, ert_config, storage):
with PluginSiteConfigEnv():
context = start_server(options.config, ert_config, storage)
print("Waiting for server ...")
wait_for_server(options.config, timeout=600, context=context)
print("Everest server found!")
run_detached_monitor(options.config, show_all_jobs=options.show_all_jobs)
wait_for_context()


def run_everest(options):
logger = logging.getLogger("everest_main")
server_state = everserver_status(options.config)

if server_is_running(options.config):
if server_is_running(*options.config.server_context):
config_file = options.config.config_file
print(
"An optimization is currently running.\n"
Expand All @@ -119,8 +109,13 @@ def run_everest(options):

makedirs_if_needed(options.config.output_dir, roll_if_exists=True)

with open_storage(ert_config.ens_path, "w") as storage:
_run_everest(options, ert_config, storage)
with open_storage(ert_config.ens_path, "w") as storage, PluginSiteConfigEnv():
context = start_server(options.config, ert_config, storage)
print("Waiting for server ...")
wait_for_server(options.config, timeout=600, context=context)
print("Everest server found!")
run_detached_monitor(options.config, show_all_jobs=options.show_all_jobs)
wait_for_context()

server_state = everserver_status(options.config)
server_state_info = server_state["message"]
Expand Down
2 changes: 1 addition & 1 deletion src/everest/bin/kill_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _handle_keyboard_interrupt(signal, frame, after=False):


def kill_everest(options):
if not server_is_running(options.config):
if not server_is_running(*options.config.server_context):
print("Server is not running.")
return

Expand Down
2 changes: 1 addition & 1 deletion src/everest/bin/monitor_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def monitor_everest(options):
config: EverestConfig = options.config
server_state = everserver_status(options.config)

if server_is_running(config):
if server_is_running(*config.server_context):
run_detached_monitor(config, show_all_jobs=options.show_all_jobs)
server_state = everserver_status(config)
if server_state["status"] == ServerStatus.failed:
Expand Down
30 changes: 20 additions & 10 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
from argparse import ArgumentParser
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING, List, Literal, Optional, Protocol, no_type_check
from typing import (
TYPE_CHECKING,
List,
Literal,
Optional,
Protocol,
Tuple,
no_type_check,
)

from pydantic import (
AfterValidator,
Expand Down Expand Up @@ -684,18 +692,20 @@ def hostfile_path(self):
def server_info(self):
"""Load server information from the hostfile"""
host_file_path = self.hostfile_path
try:
with open(host_file_path, "r", encoding="utf-8") as f:
json_string = f.read()

with open(host_file_path, "r", encoding="utf-8") as f:
json_string = f.read()

data = json.loads(json_string)
if set(data.keys()) != {"host", "port", "cert", "auth"}:
raise RuntimeError("Malformed hostfile")

return data
data = json.loads(json_string)
if set(data.keys()) != {"host", "port", "cert", "auth"}:
raise RuntimeError("Malformed hostfile")
return data
except FileNotFoundError:
# No host file
return {"host": None, "port": None, "cert": None, "auth": None}

@property
def server_context(self):
def server_context(self) -> Tuple[str, str, Tuple[str, str]]:
"""Returns a tuple with
- url of the server
- path to the .cert file
Expand Down
Loading