Skip to content

Commit

Permalink
Merge new functionality from load_run_path_realization into forward_m…
Browse files Browse the repository at this point in the history
…odel_ok
  • Loading branch information
DanSava committed Nov 4, 2024
1 parent cbc1053 commit c4f70aa
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 70 deletions.
57 changes: 8 additions & 49 deletions src/ert/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pathlib import Path

from ert.config import InvalidResponseFile
from ert.run_arg import RunArg
from ert.storage import Ensemble
from ert.storage.realization_storage_state import RealizationStorageState

Expand Down Expand Up @@ -92,63 +91,23 @@ async def _write_responses_to_storage(


async def forward_model_ok(
run_arg: RunArg,
run_path: str,
realization: int,
iter: int,
ensemble: Ensemble,
) -> LoadResult:
parameters_result = LoadResult(LoadStatus.LOAD_SUCCESSFUL, "")
response_result = LoadResult(LoadStatus.LOAD_SUCCESSFUL, "")
try:
# We only read parameters after the prior, after that, ERT
# handles parameters
if run_arg.itr == 0:
if iter == 0:
parameters_result = await _read_parameters(
run_arg.runpath,
run_arg.iens,
run_arg.ensemble_storage,
)

if parameters_result.status == LoadStatus.LOAD_SUCCESSFUL:
response_result = await _write_responses_to_storage(
run_arg.runpath,
run_arg.iens,
run_arg.ensemble_storage,
run_path,
realization,
ensemble,
)

except Exception as err:
logger.exception(
f"Failed to load results for realization {run_arg.iens}",
exc_info=err,
)
parameters_result = LoadResult(
LoadStatus.LOAD_FAILURE,
"Failed to load results for realization "
f"{run_arg.iens}, failed with: {err}",
)

final_result = parameters_result
if response_result.status != LoadStatus.LOAD_SUCCESSFUL:
final_result = response_result
run_arg.ensemble_storage.set_failure(
run_arg.iens, RealizationStorageState.LOAD_FAILURE, final_result.message
)
elif run_arg.ensemble_storage.has_failure(run_arg.iens):
run_arg.ensemble_storage.unset_failure(run_arg.iens)

return final_result


async def load_run_path_realization(
run_path: str,
realization: int,
ensemble: Ensemble,
) -> LoadResult:
response_result = LoadResult(LoadStatus.LOAD_SUCCESSFUL, "")
try:
parameters_result = await _read_parameters(
run_path,
realization,
ensemble,
)

if parameters_result.status == LoadStatus.LOAD_SUCCESSFUL:
response_result = await _write_responses_to_storage(
run_path,
Expand Down
4 changes: 2 additions & 2 deletions src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pandas import DataFrame

from ert.analysis import AnalysisEvent, SmootherSnapshot, smoother_update
from ert.callbacks import load_run_path_realization
from ert.callbacks import forward_model_ok
from ert.config import (
EnkfObservationImplementationType,
ErtConfig,
Expand Down Expand Up @@ -50,7 +50,7 @@ def _load_realization_from_run_path(
realization: int,
ensemble: Ensemble,
) -> Tuple[LoadResult, int]:
result = asyncio.run(load_run_path_realization(run_path, realization, ensemble))
result = asyncio.run(forward_model_ok(run_path, realization, 0, ensemble))
return result, realization


Expand Down
7 changes: 6 additions & 1 deletion src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,12 @@ async def _verify_checksum(
logger.error(f"Disk synchronization failed for {file_path}")

async def _handle_finished_forward_model(self) -> None:
callback_status, status_msg = await forward_model_ok(self.real.run_arg)
callback_status, status_msg = await forward_model_ok(
run_path=self.real.run_arg.runpath,
realization=self.real.run_arg.iens,
iter=self.real.run_arg.itr,
ensemble=self.real.run_arg.ensemble_storage,
)
if self._message:
self._message = status_msg
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/ert/unit_tests/scheduler/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def test_job_run_sends_expected_events(
realization: Realization,
monkeypatch,
):
async def load_result(_):
async def load_result(**_):
return (forward_model_ok_result, "")

monkeypatch.setattr(ert.scheduler.job, "forward_model_ok", load_result)
Expand Down
31 changes: 14 additions & 17 deletions tests/ert/unit_tests/test_load_forward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from ert.config import ErtConfig
from ert.enkf_main import create_run_path
from ert.libres_facade import LibresFacade
from ert.run_arg import create_run_arguments
from ert.storage import open_storage


Expand Down Expand Up @@ -290,9 +289,9 @@ def test_that_the_states_are_set_correctly():
assert new_ensemble.has_data()


@pytest.mark.parametrize("iter", [None, 0, 1, 2, 3])
@pytest.mark.parametrize("itr", [None, 0, 1, 2, 3])
@pytest.mark.usefixtures("use_tmpdir")
def test_loading_from_any_available_iter(storage, run_paths, run_args, iter):
def test_loading_from_any_available_iter(storage, run_paths, run_args, itr):
config_text = dedent(
"""
NUM_REALIZATIONS 1
Expand All @@ -308,23 +307,21 @@ def test_loading_from_any_available_iter(storage, run_paths, run_args, iter):
),
name="prior",
ensemble_size=ert_config.model_config.num_realizations,
iteration=iter if iter is not None else 0,
iteration=itr if itr is not None else 0,
)

run_args = create_run_arguments(
run_paths(ert_config),
[True] * ert_config.model_config.num_realizations,
prior_ensemble,
)
create_run_path(
run_args,
prior_ensemble,
ert_config,
run_paths(ert_config),
)
run_path = Path(
f"simulations/realization-0/iter-{iter if iter is not None else 0}/"
run_args=run_args(ert_config, prior_ensemble),
ensemble=prior_ensemble,
user_config_file=ert_config.user_config_file,
env_vars=ert_config.env_vars,
forward_model_steps=ert_config.forward_model_steps,
substitutions=ert_config.substitutions,
templates=ert_config.ert_templates,
model_config=ert_config.model_config,
runpaths=run_paths(ert_config),
)
run_path = Path(f"simulations/realization-0/iter-{itr if itr is not None else 0}/")
with open(run_path / "response.out", "w", encoding="utf-8") as fout:
fout.write("\n".join(["1", "2", "3"]))
with open(run_path / "response.out_active", "w", encoding="utf-8") as fout:
Expand All @@ -333,7 +330,7 @@ def test_loading_from_any_available_iter(storage, run_paths, run_args, iter):
facade = LibresFacade.from_config_file("config.ert")
run_path_format = str(
Path(
f"simulations/realization-<IENS>/iter-{iter if iter is not None else 0}"
f"simulations/realization-<IENS>/iter-{itr if itr is not None else 0}"
).resolve()
)
facade.load_from_run_path(run_path_format, prior_ensemble, [0])
Expand Down

0 comments on commit c4f70aa

Please sign in to comment.