Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the ert plugin mechanism to install forward models #75

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ docs = [
[project.entry-points."everest"]
everest-models = "everest_models.everest_hooks"

[project.entry-points."ert"]
everest_models_forward_models = "everest_models.forward_models"

[project.scripts]
fm_add_templates = "everest_models.jobs.fm_add_templates.cli:main_entry_point"
fm_drill_date_planner = "everest_models.jobs.fm_drill_date_planner.cli:main_entry_point"
Expand Down
2 changes: 0 additions & 2 deletions src/everest_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import pathlib

from everest_models.everest_hooks import (
get_forward_models,
get_forward_models_schemas,
parse_forward_model_schema,
)
from everest_models.logger import set_up_logger

__all__ = [
"get_forward_models",
"get_forward_models_schemas",
"parse_forward_model_schema",
]
Expand Down
32 changes: 5 additions & 27 deletions src/everest_models/everest_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import logging
import pathlib
import sys
from importlib import import_module, resources
from typing import Any, Dict, List, Sequence, Type

Expand All @@ -24,36 +23,13 @@

logger = logging.getLogger(__name__)

FORWARD_MODEL_DIR = "forward_models"
PACKAGE = "everest_models"
JOBS = f"{PACKAGE}.jobs"
JOBS = "everest_models.jobs"


def _get_jobs():
return (job for job in resources.contents(JOBS) if job.startswith("fm_"))


@hookimpl
def get_forward_models() -> List[Dict[str, str]]:
"""Accumulate all maintained forward model jobs by name and path.

Returns:
(List[Dict[str, str]]): list of forward models and corrolated path
- {name: forward_model, path: /path/to/forward_model}
- ...
"""
if sys.version_info.minor >= 9:
jobs = resources.files(PACKAGE) / FORWARD_MODEL_DIR # type: ignore
else:
with resources.path(PACKAGE, FORWARD_MODEL_DIR) as fd:
jobs = fd

return [
{"name": (job_name := job.lstrip("fm_")), "path": str(jobs / job_name)}
for job in _get_jobs()
]


@hookimpl
def get_forward_models_schemas() -> Dict[str, Dict[str, Type[BaseModel]]]:
"""Accumulate all forward model jobs and schemas.
Expand All @@ -75,7 +51,9 @@ def get_forward_models_schemas() -> Dict[str, Dict[str, Type[BaseModel]]]:
for job in _get_jobs():
schema = getattr(import_module(f"{JOBS}.{job}.parser"), "SCHEMAS", None)
if schema:
res[job.lstrip("fm_")] = schema.get("-c/--config") or schema.get("config")
res[job[3:] if job.startswith("fm_") else job] = schema.get(
"-c/--config"
) or schema.get("config")
return res


Expand Down Expand Up @@ -123,7 +101,7 @@ def get_forward_model_documentations() -> Dict[str, Any]:
import_module(f"{JOBS}.{job}.cli"), "FULL_JOB_NAME", cmd_name
)
examples = getattr(import_module(f"{JOBS}.{job}.cli"), "EXAMPLES", None)
docs[job.lstrip("fm_")] = {
docs[job[3:] if job.startswith("fm_") else job] = {
"cmd_name": cmd_name,
"examples": examples,
"full_job_name": full_job_name,
Expand Down
54 changes: 54 additions & 0 deletions src/everest_models/forward_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from importlib import resources
from importlib.util import find_spec
from typing import Final, List, Type

_HAVE_ERT: Final = find_spec("ert") is not None


def get_forward_models() -> List[str]:
"""Return the list of forward model names."""
return [
job[3:]
for job in resources.contents("everest_models.jobs")
if job.startswith("fm_")
]


if _HAVE_ERT: # The everest-models package should remain installable without ERT.
import ert
from ert import ForwardModelStepDocumentation, ForwardModelStepPlugin

def build_forward_model_step_plugin(
executable_name: str,
) -> Type[ForwardModelStepPlugin]:
forward_model_name = (
executable_name[3:]
if executable_name.startswith("fm_")
else executable_name
)
class_name = "".join(
x.capitalize() for x in forward_model_name.lower().split("_")
)
return type(
class_name,
(ForwardModelStepPlugin,),
{
"__init__": lambda x: ForwardModelStepPlugin.__init__(
x, name=forward_model_name, command=[executable_name]
),
"documentation": lambda: ForwardModelStepDocumentation(
category="everest.everest_models",
source_package="everest_models",
source_function_name=class_name,
description=f"The {forward_model_name} forward model.",
),
},
)

@ert.plugin(name="everest_models")
def installable_forward_model_steps():
return [
build_forward_model_step_plugin(job)
for job in resources.contents("everest_models.jobs")
if job.startswith("fm_")
]
1 change: 0 additions & 1 deletion src/everest_models/forward_models/add_templates

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/compute_economics

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/drill_date_planner

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/drill_planner

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/extract_summary_data

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/interpret_well_drill

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/npv

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/rf

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/schmerge

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/select_wells

This file was deleted.

4 changes: 0 additions & 4 deletions src/everest_models/forward_models/stea

This file was deleted.

3 changes: 0 additions & 3 deletions src/everest_models/forward_models/strip_dates

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/well_constraints

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/well_filter

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/well_swapping

This file was deleted.

1 change: 0 additions & 1 deletion src/everest_models/forward_models/well_trajectory

This file was deleted.

3 changes: 0 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ class TestSpec:

hookspec = pluggy.HookspecMarker("test")

@hookspec
def get_forward_models(self): ...

@hookspec
def get_forward_models_schemas(self): ...

Expand Down
28 changes: 0 additions & 28 deletions tests/integration/test_plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
import sys
from pathlib import Path

Expand All @@ -16,33 +15,6 @@ def test_hooks_registered(plugin_manager):
assert sys.modules["everest_models.everest_hooks"] in plugin_manager.get_plugins()


def test_get_forward_models_hook(plugin_manager):
jobs = {
"stea": f"{FORWARD_MODEL_DIR}/stea",
"drill_planner": f"{FORWARD_MODEL_DIR}/drill_planner",
"compute_economics": f"{FORWARD_MODEL_DIR}/compute_economics",
"schmerge": f"{FORWARD_MODEL_DIR}/schmerge",
"extract_summary_data": f"{FORWARD_MODEL_DIR}/extract_summary_data",
"drill_date_planner": f"{FORWARD_MODEL_DIR}/drill_date_planner",
"strip_dates": f"{FORWARD_MODEL_DIR}/strip_dates",
"select_wells": f"{FORWARD_MODEL_DIR}/select_wells",
"npv": f"{FORWARD_MODEL_DIR}/npv",
"well_constraints": f"{FORWARD_MODEL_DIR}/well_constraints",
"add_templates": f"{FORWARD_MODEL_DIR}/add_templates",
"rf": f"{FORWARD_MODEL_DIR}/rf",
"well_filter": f"{FORWARD_MODEL_DIR}/well_filter",
"interpret_well_drill": f"{FORWARD_MODEL_DIR}/interpret_well_drill",
"well_trajectory": f"{FORWARD_MODEL_DIR}/well_trajectory",
"well_swapping": f"{FORWARD_MODEL_DIR}/well_swapping",
}
assert all(
jobs[job["name"]] in job["path"]
for job in itertools.chain.from_iterable(
plugin_manager.hook.get_forward_models()
)
)


def test_get_forward_model_schemas_hook(plugin_manager):
assert not set(plugin_manager.hook.get_forward_models_schemas().pop()) - {
"add_templates",
Expand Down