Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store ert<->everest realization mapping #9767

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 123 additions & 121 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import os
import queue
import shutil
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from enum import IntEnum
Expand All @@ -16,6 +15,7 @@
from typing import TYPE_CHECKING, Any, Protocol

import numpy as np
import polars
import seba_sqlite.sqlite_storage
from numpy import float64
from numpy._typing import NDArray
Expand All @@ -37,10 +37,12 @@
from everest.strings import EVEREST

from ..run_arg import RunArg, create_run_arguments
from ..storage.everest_ensemble import EverestRealizationInfo
from ..storage.everest_experiment import EverestExperiment
from .base_run_model import BaseRunModel, StatusEvents

if TYPE_CHECKING:
from ert.storage import Ensemble, Experiment
from ert.storage import Ensemble


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -122,15 +124,7 @@ def __init__(
self._fm_errors: dict[int, dict[str, Any]] = {}
self._result: OptimalResult | None = None
self._exit_code: EverestExitCode | None = None
self._simulator_cache = (
SimulatorCache()
if (
everest_config.simulator is not None
and everest_config.simulator.enable_cache
)
else None
)
self._experiment: Experiment | None = None
self._experiment: EverestExperiment | None = None
self._eval_server_cfg: EvaluatorServerConfig | None = None
self._batch_id: int = 0
self._status: SimulationStatus | None = None
Expand Down Expand Up @@ -197,11 +191,16 @@ def run_experiment(
) -> None:
self.log_at_startup()
self._eval_server_cfg = evaluator_server_config
self._experiment = self._storage.create_experiment(
name=f"EnOpt@{datetime.datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
parameters=self._parameter_configuration,
responses=self._response_configuration,
)

# Reuse the existing experiment on re-runs; for now this only works in-memory.
# TODO: If an experiment for this exact config already exists,
# we should load the experiment from the file system
if self._experiment is None:
self._experiment = self._storage.create_everest_experiment(
parameters=self._parameter_configuration,
responses=self._response_configuration,
name=f"EnOpt@{datetime.datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
)

# Initialize the ropt optimizer:
optimizer = self._create_optimizer()
Expand Down Expand Up @@ -371,43 +370,74 @@ def _forward_model_evaluator(

# Initialize a new ensemble in storage:
assert self._experiment is not None
ensemble = self._experiment.create_ensemble(
everest_ensemble = self._experiment.create_ensemble(
name=f"batch_{self._batch_id}",
ensemble_size=len(batch_data),
)
ert_ensemble = everest_ensemble.ert_ensemble

realizations = self._everest_config.model.realizations
num_perturbations = (
1
if self._everest_config.optimization is None
else self._everest_config.optimization.perturbation_num or 1
)

realization_mapping: dict[int, EverestRealizationInfo] = {}
if len(evaluator_context.realizations) == len(realizations):
# Function evaluation
realization_mapping = {
i: EverestRealizationInfo(geo_realization=real, perturbation=None)
for i, real in enumerate(realizations)
}
elif len(evaluator_context.realizations) == num_perturbations:
realization_mapping = {
p: EverestRealizationInfo(geo_realization=real, perturbation=p)
for p, real in enumerate(realizations)
}
else:
# Function and gradient
for i, real in enumerate(realizations):
realization_mapping[i] = EverestRealizationInfo(
geo_realization=real, perturbation=None
)

i = len(realization_mapping)
for real in realizations:
for p in range(num_perturbations or 1):
realization_mapping[i] = EverestRealizationInfo(
geo_realization=real, perturbation=p
)
i += 1

# Fill in data from ROPT here
everest_ensemble.save_realization_mapping(
realization_mapping=realization_mapping
)
for sim_id, controls in enumerate(batch_data.values()):
self._setup_sim(sim_id, controls, ensemble)
self._setup_sim(sim_id, controls, ert_ensemble)

# Evaluate the batch:
run_args = self._get_run_args(ensemble, evaluator_context, batch_data)
run_args = self._get_run_args(ert_ensemble, evaluator_context, batch_data)
self._context_env.update(
{
"_ERT_EXPERIMENT_ID": str(ensemble.experiment_id),
"_ERT_ENSEMBLE_ID": str(ensemble.id),
"_ERT_EXPERIMENT_ID": str(ert_ensemble.experiment_id),
"_ERT_ENSEMBLE_ID": str(ert_ensemble.id),
"_ERT_SIMULATION_MODE": "batch_simulation",
}
)
assert self._eval_server_cfg is not None
self._evaluate_and_postprocess(run_args, ensemble, self._eval_server_cfg)
self._evaluate_and_postprocess(run_args, ert_ensemble, self._eval_server_cfg)

# If necessary, delete the run path:
self._delete_runpath(run_args)

# Gather the results and create the result for ropt:
results = self._gather_simulation_results(ensemble)
results = self._gather_simulation_results(ert_ensemble)
evaluator_result = self._make_evaluator_result(
control_values, batch_data, results, cached_results
)

# Add the results from the evaluations to the cache:
self._add_results_to_cache(
control_values,
evaluator_context,
batch_data,
evaluator_result.objectives,
evaluator_result.constraints,
)

# Increase the batch ID for the next evaluation:
self._batch_id += 1

Expand All @@ -416,16 +446,60 @@ def _forward_model_evaluator(
def _get_cached_results(
self, control_values: NDArray[np.float64], evaluator_context: EvaluatorContext
) -> dict[int, Any]:
cached_results: dict[int, Any] = {}
if self._simulator_cache is not None:
for control_idx, real_idx in enumerate(evaluator_context.realizations):
cached_data = self._simulator_cache.get(
self._everest_config.model.realizations[real_idx],
control_values[control_idx, :],
assert self._experiment is not None

control_groups = {c.name: c for c in self._everest_config.controls}
control_variables = {g: len(c.variables) for g, c in control_groups.items()}
control_group_spans: list[tuple[int, int]] = []
span_: int = 0
for num_vars in control_variables.values():
control_group_spans.append((span_, span_ + int(num_vars)))
span_ += int(num_vars)

def controls_1d_to_dict(values_: list[float]) -> dict[str, list[float]]:
return {
group: values_[span_[0] : span_[1]]
for group, span_ in zip(
control_groups, control_group_spans, strict=False
)
if cached_data is not None:
cached_results[control_idx] = cached_data
return cached_results
}

cached_results2: dict[int, Any] = {}
for control_values_ in control_values.tolist():
parameter_group_values = controls_1d_to_dict(control_values_)

# If several realizations have the same controls,
# but different responses, make sure to not get stuck
# on the first found realization.
ert_realization, matching_ensemble = (
self._experiment.find_realization_with_data(
parameter_group_values, exclude=cached_results2.keys()
)
)
if ert_realization is not None:
assert matching_ensemble is not None
responses = matching_ensemble.ert_ensemble.load_responses(
"gen_data", (ert_realization,)
)
objectives = responses.filter(
polars.col("response_key").is_in(
self._everest_config.objective_names
)
)["values"]

constraints = responses.filter(
polars.col("response_key").is_in(
self._everest_config.constraint_names
)
)["values"]

cached_data = (
objectives.to_numpy() * -1,
constraints.to_numpy() if not constraints.is_empty() else None,
)
cached_results2[ert_realization] = cached_data

return cached_results2

def _init_batch_data(
self,
Expand Down Expand Up @@ -609,15 +683,14 @@ def _make_evaluator_result(
batch_data,
)

if self._simulator_cache is not None:
for control_idx, (
cached_objectives,
cached_constraints,
) in cached_results.items():
objectives[control_idx, ...] = cached_objectives
if constraints is not None:
assert cached_constraints is not None
constraints[control_idx, ...] = cached_constraints
for control_idx, (
cached_objectives,
cached_constraints,
) in cached_results.items():
objectives[control_idx, ...] = cached_objectives
if constraints is not None:
assert cached_constraints is not None
constraints[control_idx, ...] = cached_constraints

sim_ids = np.full(control_values.shape[0], -1, dtype=np.intc)
sim_ids[list(batch_data.keys())] = np.arange(len(batch_data), dtype=np.intc)
Expand All @@ -644,25 +717,6 @@ def _get_simulation_results(
)
return values

def _add_results_to_cache(
self,
control_values: NDArray[np.float64],
evaluator_context: EvaluatorContext,
batch_data: dict[int, Any],
objectives: NDArray[np.float64],
constraints: NDArray[np.float64] | None,
) -> None:
if self._simulator_cache is not None:
for control_idx in batch_data:
self._simulator_cache.add(
self._everest_config.model.realizations[
evaluator_context.realizations[control_idx]
],
control_values[control_idx, ...],
objectives[control_idx, ...],
None if constraints is None else constraints[control_idx, ...],
)

def check_if_runpath_exists(self) -> bool:
return (
self._everest_config.simulation_dir is not None
Expand Down Expand Up @@ -744,55 +798,3 @@ def _handle_errors(
elif fm_id not in self._fm_errors[error_hash]["ids"]:
self._fm_errors[error_hash]["ids"].append(fm_id)
error_id = self._fm_errors[error_hash]["error_id"]
fm_logger.error(err_msg.format("Already reported as", error_id))


class SimulatorCache:
EPS = float(np.finfo(np.float32).eps)

def __init__(self) -> None:
self._data: defaultdict[
int,
list[
tuple[
NDArray[np.float64], NDArray[np.float64], NDArray[np.float64] | None
]
],
] = defaultdict(list)

def add(
self,
realization: int,
control_values: NDArray[np.float64],
objectives: NDArray[np.float64],
constraints: NDArray[np.float64] | None,
) -> None:
"""Add objective and constraints for a given realization and control values.

The realization is the index of the realization in the ensemble, as specified
in by the realizations entry in the everest model configuration. Both the control
values and the realization are used as keys to retrieve the objectives and
constraints later.
"""
self._data[realization].append(
(
control_values.copy(),
objectives.copy(),
None if constraints is None else constraints.copy(),
),
)

def get(
self, realization: int, controls: NDArray[np.float64]
) -> tuple[NDArray[np.float64], NDArray[np.float64] | None] | None:
"""Get objective and constraints for a given realization and control values.

The realization is the index of the realization in the ensemble, as specified
in by the realizations entry in the everest model configuration. Both the control
values and the realization are used as keys to retrieve the objectives and
constraints from the cached values.
"""
for control_values, objectives, constraints in self._data.get(realization, []):
if np.allclose(controls, control_values, rtol=0.0, atol=self.EPS):
return objectives, constraints
return None
37 changes: 37 additions & 0 deletions src/ert/storage/everest_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pydantic import BaseModel
from typing_extensions import TypedDict

from .local_ensemble import LocalEnsemble


class EverestRealizationInfo(TypedDict):
geo_realization: int
perturbation: int | None # None means it stems from unperturbed controls
# Q: Maybe we also need result ID, or no? Ref if we have multiple evaluations
# for unperturbed values, though the ERT real id will also differentiate them


class _Index(BaseModel):
    """Pydantic model of the Everest index stored alongside an ERT ensemble."""

    # Maps ERT realization index -> Everest realization info.
    # Defaults to None, i.e. no mapping has been set.
    ert2ev_realization_mapping: dict[int, EverestRealizationInfo] | None = None


class EverestEnsemble:
    """Wrap an ERT ``LocalEnsemble`` and attach Everest-specific index data.

    The Everest index is persisted as a JSON file inside the directory of the
    wrapped ERT ensemble.
    """

    # File name of the Everest index within the ERT ensemble directory.
    _INDEX_FILENAME = "everest_index.json"

    def __init__(self, ert_ensemble: LocalEnsemble):
        """Wrap ``ert_ensemble``, starting with an empty in-memory index."""
        self._ert_ensemble = ert_ensemble
        # NOTE(review): an index file already on disk is not read here - the
        # index always starts out empty. Confirm this is intended.
        self._index = _Index()

    @property
    def ert_ensemble(self) -> LocalEnsemble:
        """The wrapped ERT ensemble."""
        return self._ert_ensemble

    def save_realization_mapping(
        self, realization_mapping: dict[int, EverestRealizationInfo]
    ) -> None:
        """Persist the ERT -> Everest realization mapping for this ensemble.

        Args:
            realization_mapping: Everest realization info keyed by ERT
                realization index.
        """
        index_path = self._ert_ensemble._path / self._INDEX_FILENAME

        self._index.ert2ev_realization_mapping = realization_mapping
        self._ert_ensemble._storage._write_transaction(
            index_path,
            self._index.model_dump_json().encode("utf-8"),
        )
        # Reload from the file just written so the in-memory index is exactly
        # what validation of the persisted JSON yields.
        self._index = _Index.model_validate_json(
            index_path.read_text(encoding="utf-8")
        )
Loading
Loading