Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug where site config was not propagated to Everest config #9719

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ert/config/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .analysis_mode import AnalysisMode
from .base_model_context import BaseModelWithContextSupport
from .config_dict import ConfigDict
from .config_errors import ConfigValidationError, ConfigWarning
from .config_keywords import ConfigKeys
Expand All @@ -20,6 +21,7 @@

__all__ = [
"AnalysisMode",
"BaseModelWithContextSupport",
"ConfigDict",
"ConfigKeys",
"ConfigValidationError",
Expand Down
26 changes: 26 additions & 0 deletions src/ert/config/parsing/base_model_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any

from pydantic import BaseModel

init_context_var = ContextVar("_init_context_var", default=None)


@contextmanager
def init_context(value: dict[str, Any]) -> Iterator[None]:
token = init_context_var.set(value) # type: ignore
try:
yield
finally:
init_context_var.reset(token)


class BaseModelWithContextSupport(BaseModel):
def __init__(__pydantic_self__, **data: Any) -> None:
__pydantic_self__.__pydantic_validator__.validate_python(
data,
self_instance=__pydantic_self__,
context=init_context_var.get(),
)
20 changes: 17 additions & 3 deletions src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from typing import Annotated, Any, Literal, no_type_check

import pydantic
from pydantic import BaseModel, Field
from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
from pydantic_core.core_schema import ValidationInfo

from ._get_num_cpu import get_num_cpu_from_data_file
from .parsing import (
BaseModelWithContextSupport,
ConfigDict,
ConfigKeys,
ConfigValidationError,
Expand All @@ -38,7 +40,7 @@ def activate_script() -> str:


class QueueOptions(
BaseModel,
BaseModelWithContextSupport,
validate_assignment=True,
extra="forbid",
use_enum_values=True,
Expand All @@ -48,7 +50,19 @@ class QueueOptions(
max_running: pydantic.NonNegativeInt = 0
submit_sleep: pydantic.NonNegativeFloat = 0.0
project_code: str | None = None
activate_script: str = Field(default_factory=activate_script)
activate_script: str | None = Field(default=None, validate_default=True)

@field_validator("activate_script", mode="before")
@classmethod
def inject_site_config_script(cls, v: str, info: ValidationInfo) -> str:
# User value gets highest priority
if isinstance(v, str):
return v
# Use from plugin system if user has not specified
plugin_script = None
if info.context:
plugin_script = info.context.get(info.field_name)
Copy link
Contributor

@yngve-sk yngve-sk Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a test that runs this line? EDIT: I see running test_detached triggers it, but do any of the ERT tests run it? Is this mainly meant for things we run via Everest?

return plugin_script or activate_script() # Return default value

@staticmethod
def create_queue_options(
Expand Down
14 changes: 13 additions & 1 deletion src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from ruamel.yaml import YAML, YAMLError

from ert.config import ErtConfig
from ert.config.parsing import BaseModelWithContextSupport
from ert.config.parsing.base_model_context import init_context
from ert.plugins import ErtPluginManager
from everest.config.control_variable_config import ControlVariableGuessListConfig
from everest.config.install_template_config import InstallTemplateConfig
from everest.config.server_config import ServerConfig
Expand Down Expand Up @@ -134,7 +137,7 @@ class HasName(Protocol):
name: str


class EverestConfig(BaseModelWithPropertySupport): # type: ignore
class EverestConfig(BaseModelWithPropertySupport, BaseModelWithContextSupport): # type: ignore
controls: Annotated[list[ControlConfig], AfterValidator(unique_items)] = Field(
description="""Defines a list of controls.
Controls should have unique names each control defines
Expand Down Expand Up @@ -807,6 +810,15 @@ def load_file(config_file: str) -> "EverestConfig":

raise exp from error

@classmethod
def with_plugins(cls, config_dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit confused as to why mypy is not complaining about the missing type for the argument config_dict shouldn't there be a : dict[str, Any] or something similar?

context = {}
activate_script = ErtPluginManager().activate_script()
if activate_script:
context["activate_script"] = ErtPluginManager().activate_script()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can just have

context["activate_script"] = activate_script

in place of

context["activate_script"] = ErtPluginManager().activate_script()

Copy link
Contributor

@yngve-sk yngve-sk Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a need to re-invoke ErtPluginManager().activate_script() after having stored it in a variable?

with init_context(context):
return cls(**config_dict)

@staticmethod
def load_file_with_argparser(
config_path, parser: ArgumentParser
Expand Down
12 changes: 1 addition & 11 deletions src/everest/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import os
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator

from ert.config.queue_config import (
LocalQueueOptions,
LsfQueueOptions,
SlurmQueueOptions,
TorqueQueueOptions,
)
from ert.plugins import ErtPluginManager

from ..strings import (
CERTIFICATE_DIR,
Expand Down Expand Up @@ -38,15 +37,6 @@ class ServerConfig(BaseModel): # type: ignore
extra="forbid",
)

@field_validator("queue_system", mode="before")
@classmethod
def default_local_queue(cls, v):
if v is None:
return v
elif "activate_script" not in v and ErtPluginManager().activate_script():
v["activate_script"] = ErtPluginManager().activate_script()
return v

@model_validator(mode="before")
@classmethod
def check_old_config(cls, data: Any) -> Any:
Expand Down
5 changes: 0 additions & 5 deletions src/everest/config/simulator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
SlurmQueueOptions,
TorqueQueueOptions,
)
from ert.plugins import ErtPluginManager

simulator_example = {"queue_system": {"name": "local", "max_running": 3}}

Expand Down Expand Up @@ -97,10 +96,6 @@ class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore
def default_local_queue(cls, v):
if v is None:
return LocalQueueOptions(max_running=8)
if "activate_script" not in v and (
active_script := ErtPluginManager().activate_script()
):
v["activate_script"] = active_script
return v

@model_validator(mode="before")
Expand Down
31 changes: 30 additions & 1 deletion tests/everest/test_detached.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def test_generate_queue_options_use_simulator_values(
queue_options, expected_result, monkeypatch
):
monkeypatch.setattr(
everest.config.server_config.ErtPluginManager,
everest.config.everest_config.ErtPluginManager,
"activate_script",
MagicMock(return_value=activate_script()),
)
Expand All @@ -295,6 +295,35 @@ def test_generate_queue_options_use_simulator_values(
assert config.server.queue_system == expected_result


@pytest.mark.parametrize("use_plugin", (True, False))
@pytest.mark.parametrize(
"queue_options",
[
{"name": "slurm", "activate_script": "From user"},
{"name": "slurm"},
],
)
def test_queue_options_site_config(queue_options, use_plugin, monkeypatch, min_config):
plugin_result = "From plugin"
if "activate_script" in queue_options:
expected_result = queue_options["activate_script"]
elif use_plugin:
expected_result = plugin_result
else:
expected_result = activate_script()

if use_plugin:
monkeypatch.setattr(
everest.config.everest_config.ErtPluginManager,
"activate_script",
MagicMock(return_value=plugin_result),
)
config = EverestConfig.with_plugins(
{"simulator": {"queue_system": queue_options}} | min_config
)
assert config.server.queue_system.activate_script == expected_result


@pytest.mark.timeout(5) # Simulation might not finish
@pytest.mark.integration_test
@pytest.mark.xdist_group(name="starts_everest")
Expand Down