Skip to content

Commit

Permalink
Rename load_priors -> save_design_matrix_ensemble
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Oct 17, 2024
1 parent be55192 commit 457efaa
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 20 deletions.
25 changes: 9 additions & 16 deletions src/ert/enkf_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,13 @@
import xarray as xr
from numpy.random import SeedSequence

from .config import (
DesignMatrix,
ExtParamConfig,
Field,
GenKwConfig,
ParameterConfig,
SurfaceConfig,
)
from .config import ExtParamConfig, Field, GenKwConfig, ParameterConfig, SurfaceConfig
from .run_arg import RunArg
from .runpaths import Runpaths

if TYPE_CHECKING:
import pandas as pd

from .config import ErtConfig
from .storage import Ensemble

Expand Down Expand Up @@ -150,16 +145,14 @@ def _seed_sequence(seed: Optional[int]) -> int:
return int_seed


def load_prior(
ensemble: Ensemble, active_realizations: Iterable[int], design_matrix: DesignMatrix
def save_design_matrix_to_ensemble(
design_matrix_df: pd.DataFrame,
ensemble: Ensemble,
active_realizations: Iterable[int],
) -> None:
assert design_matrix.parameter_configuration is not None
assert (
design_matrix.design_matrix_df is not None
and not design_matrix.design_matrix_df.empty
)
assert not design_matrix_df.empty
for realization_nr in active_realizations:
row = design_matrix.design_matrix_df.loc[realization_nr]["DESIGN_MATRIX"]
row = design_matrix_df.loc[realization_nr]["DESIGN_MATRIX"]
ds = xr.Dataset(
{
"values": ("names", list(row.values)),
Expand Down
12 changes: 8 additions & 4 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from ert.enkf_main import load_prior, sample_prior
from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Experiment, Storage

Expand Down Expand Up @@ -115,11 +115,15 @@ def run_experiment(
np.array(self.active_realizations, dtype=bool),
ensemble=self.ensemble,
)
if self.ert_config.analysis_config.design_matrix is not None:
load_prior(
if (
self.ert_config.analysis_config.design_matrix is not None
and self.ert_config.analysis_config.design_matrix.design_matrix_df
is not None
):
save_design_matrix_to_ensemble(
self.ert_config.analysis_config.design_matrix.design_matrix_df,
self.ensemble,
np.where(self.active_realizations)[0],
self.ert_config.analysis_config.design_matrix,
)
else:
sample_prior(
Expand Down

0 comments on commit 457efaa

Please sign in to comment.