Skip to content

Commit

Permalink
Store everest realization mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Jan 24, 2025
1 parent 0d61c06 commit 59b19b9
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 5 deletions.
87 changes: 82 additions & 5 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from everest.strings import EVEREST

from ..run_arg import RunArg, create_run_arguments
from ..storage.everest_ensemble import EverestRealizationInfo
from ..storage.everest_experiment import EverestExperiment
from .base_run_model import BaseRunModel, StatusEvents

Expand Down Expand Up @@ -199,11 +200,15 @@ def run_experiment(
self.log_at_startup()
self._eval_server_cfg = evaluator_server_config

self._experiment = self._storage.create_everest_experiment(
name=f"EnOpt@{datetime.datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
parameters=self._parameter_configuration,
responses=self._response_configuration,
)
# Keep for re-runs, will work in-memory
# 2DO: If an experiment for this exact config already exists,
# we should load the experiment from the file system
if self._experiment is None:
self._experiment = self._storage.create_everest_experiment(
name=f"EnOpt@{datetime.datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
parameters=self._parameter_configuration,
responses=self._response_configuration,
)

# Initialize the ropt optimizer:
optimizer = self._create_optimizer()
Expand Down Expand Up @@ -378,6 +383,42 @@ def _forward_model_evaluator(
ensemble_size=len(batch_data),
)
ert_ensemble = everest_ensemble.ert_ensemble

realizations = self._everest_config.model.realizations
num_perturbations = self._everest_config.optimization.perturbation_num

Check failure on line 388 in src/ert/run_models/everest_run_model.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Item "None" of "OptimizationConfig | None" has no attribute "perturbation_num"
realization_mapping: dict[int, EverestRealizationInfo] = {}

if len(evaluator_context.realizations) == len(realizations):
# Function evaluation
realization_mapping = {
i: EverestRealizationInfo(geo_realization=real, perturbation=None)
for i, real in enumerate(realizations)
}
elif len(evaluator_context.realizations) == num_perturbations:
realization_mapping = {
p: EverestRealizationInfo(geo_realization=real, perturbation=p)
for p, real in enumerate(realizations)
}
else:
# Function and gradient
realization_mapping = {}
for i, real in enumerate(realizations):
realization_mapping[i] = EverestRealizationInfo(
geo_realization=real, perturbation=None
)

i = len(realization_mapping)
for real in realizations:
for p in range(num_perturbations or 1):
realization_mapping[i] = EverestRealizationInfo(
geo_realization=real, perturbation=p
)
i += 1

# Fill in data from ROPT here
everest_ensemble.save_realization_mapping(
realization_mapping=realization_mapping
)
for sim_id, controls in enumerate(batch_data.values()):
self._setup_sim(sim_id, controls, ert_ensemble)

Expand Down Expand Up @@ -419,6 +460,22 @@ def _forward_model_evaluator(
def _get_cached_results(
self, control_values: NDArray[np.float64], evaluator_context: EvaluatorContext
) -> dict[int, Any]:
control_groups = {c.name: c for c in self._everest_config.controls}
control_variables = {g: len(c.variables) for g, c in control_groups.items()}
control_group_spans = []
span_ = 0
for num_vars in control_variables.values():
control_group_spans.append((span_, span_ + num_vars))
span_ += num_vars

def controls_1d_to_dict(values_: list[float]):

Check failure on line 471 in src/ert/run_models/everest_run_model.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Function is missing a return type annotation
return {
group: values_[span_[0] : span_[1]]
for group, span_ in zip(
control_groups, control_group_spans, strict=False
)
}

cached_results: dict[int, Any] = {}
if self._simulator_cache is not None:
for control_idx, real_idx in enumerate(evaluator_context.realizations):
Expand All @@ -428,6 +485,26 @@ def _get_cached_results(
)
if cached_data is not None:
cached_results[control_idx] = cached_data

cached_results2: dict[int, Any] = {}
for control_idx, control_values_ in enumerate(control_values.tolist()):
parameter_group_values = controls_1d_to_dict(control_values_)
matching_realization, matching_ensemble = (
self._experiment.find_realization_by_parameter_values(

Check failure on line 493 in src/ert/run_models/everest_run_model.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Item "None" of "EverestExperiment | None" has no attribute "find_realization_by_parameter_values"
parameter_group_values
)
)
if matching_realization is not None:
assert matching_ensemble is not None
cached_data = matching_ensemble.ert_ensemble.load_responses(

Check failure on line 499 in src/ert/run_models/everest_run_model.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Incompatible types in assignment (expression has type "ndarray[Any, Any] | Any", variable has type "tuple[ndarray[Any, dtype[floating[_64Bit]]], ndarray[Any, dtype[floating[_64Bit]]] | None] | None")
"gen_data", (matching_realization,)
)["values"].to_numpy()
print(control_idx)
# 2do

if cached_results2 != cached_results:
print("FAIL")

return cached_results

def _init_batch_data(
Expand Down
27 changes: 27 additions & 0 deletions src/ert/storage/everest_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,37 @@
from pydantic import BaseModel
from typing_extensions import TypedDict

from .local_ensemble import LocalEnsemble


class EverestRealizationInfo(TypedDict):
    """Everest-side description of one realization in an ensemble.

    Used as the value type of the realization mapping stored in ``_Index``.
    """

    geo_realization: int
    perturbation: int | None  # None means it stems from unperturbed controls
    # Open question: do we also need a result ID here? There may be several
    # evaluations of the unperturbed values, although the ERT realization id
    # already tells them apart.


class _Index(BaseModel):
    """Index data that ``EverestEnsemble`` serializes to ``everest_index.json``."""

    # Realization mapping keyed by an integer realization index (presumably
    # the ERT realization id, going by the field name -- confirm with the
    # caller). Stays None until a mapping has been saved.
    ert2ev_realization_mapping: dict[int, EverestRealizationInfo] | None = None


class EverestEnsemble:
    """Wrapper around a ``LocalEnsemble`` that adds an Everest index.

    The index is held in memory as an ``_Index`` model and is written to
    ``everest_index.json`` inside the wrapped ensemble's directory whenever
    a realization mapping is saved.
    """

    def __init__(self, ert_ensemble: LocalEnsemble) -> None:
        """Wrap ``ert_ensemble`` and start with an empty index."""
        self._ert_ensemble = ert_ensemble
        self._index = _Index()

    @property
    def ert_ensemble(self) -> LocalEnsemble:
        """The wrapped ERT ensemble."""
        return self._ert_ensemble

    def save_realization_mapping(
        self, realization_mapping: dict[int, EverestRealizationInfo]
    ) -> None:
        """Store ``realization_mapping`` in the index and write it to disk.

        Args:
            realization_mapping: Everest realization info keyed by an
                integer realization index.
        """
        self._index.ert2ev_realization_mapping = realization_mapping
        index_path = self.ert_ensemble._path / "everest_index.json"
        self._ert_ensemble._storage._write_transaction(
            index_path,
            self._index.model_dump_json().encode("utf-8"),
        )
        # Reload from the file that was just written so the in-memory index
        # is exactly what a later reader of the JSON file would see.
        self._index = _Index.model_validate_json(
            index_path.read_text(encoding="utf-8")
        )
34 changes: 34 additions & 0 deletions src/ert/storage/everest_experiment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from functools import cached_property

import numpy as np

from ert.storage import LocalExperiment

from .everest_ensemble import EverestEnsemble
Expand Down Expand Up @@ -28,4 +30,36 @@ def create_ensemble(
everest_ensemble = EverestEnsemble(ert_ensemble)
if self.ensembles is not None:
del self.ensembles # Clear cache when a new ensemble is created

return everest_ensemble

def find_realization_by_parameter_values(
    self, parameter_values: dict[str, np.ndarray]
) -> tuple[int, EverestEnsemble] | tuple[None, None]:
    """Find a stored realization whose parameters match the given values.

    Ensembles are scanned in the order ``self.ensembles`` yields them, and
    realizations in index order; the first match wins.

    Args:
        parameter_values: Values to look for, keyed by parameter group
            name. Only the groups named here are loaded and compared.

    Returns:
        ``(realization_index, ensemble)`` for the first realization where
        every group is ``np.allclose`` to the requested values, or
        ``(None, None)`` when nothing matches (including when there are no
        ensembles).
    """
    # No separate emptiness check: iterating an empty ``self.ensembles``
    # falls straight through to the final return.
    for e in self.ensembles:
        ensemble_size = e.ert_ensemble.ensemble_size

        # One row per realization, with each group's values flattened.
        ens_parameters = {
            group: e.ert_ensemble.load_parameters(group)
            .to_dataarray()
            .data.reshape((ensemble_size, -1))
            for group in parameter_values
        }

        matching_real = next(
            (
                i
                for i in range(ensemble_size)
                if all(
                    np.allclose(ens_parameters[group][i], group_data)
                    for group, group_data in parameter_values.items()
                )
            ),
            None,
        )

        if matching_real is not None:
            return matching_real, e

    return None, None
31 changes: 31 additions & 0 deletions src/ert/storage/local_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,34 @@ def _update_response_keys(

if self.response_type_to_response_keys is not None:
del self.response_type_to_response_keys

def find_realization_by_parameter_values(
self, parameter_values: dict[str, np.array]

Check failure on line 395 in src/ert/storage/local_experiment.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Function "numpy.core.multiarray.array" is not valid as a type
) -> int | None:
if not list(self.ensembles):
return None

for ensemble in self.ensembles:
ens_parameters = {
group: ensemble.load_parameters(group)
.to_dataarray()
.data.reshape((ensemble.ensemble_size, -1))
for group in parameter_values
}

matching_real = next(
(
i
for i in range(ensemble.ensemble_size)
if all(
np.allclose(ens_parameters[group][i], group_data, atol=1e-8)
for group, group_data in parameter_values.items()
)
),
None,
)

if matching_real is not None:
return matching_real

return None

0 comments on commit 59b19b9

Please sign in to comment.