Skip to content

Commit

Permalink
Save design matrix parameters in sample_prior
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Oct 28, 2024
1 parent ec821e1 commit 54f7a2c
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 20 deletions.
8 changes: 6 additions & 2 deletions src/ert/config/design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ def merge_with_existing_parameters(
return

new_param_config: List[ParameterConfig] = []
design_keys = self.parameter_configuration[DESIGN_MATRIX_GROUP].getKeyWords()
if isinstance(self.parameter_configuration[DESIGN_MATRIX_GROUP], GenKwConfig):
design_keys = self.parameter_configuration[
DESIGN_MATRIX_GROUP
].getKeyWords()

design_group_added = False
for genkw_group in existing_parameters:
if not isinstance(genkw_group, GenKwConfig):
Expand Down Expand Up @@ -161,7 +165,7 @@ def read_design_matrix(
output_file=None,
transform_function_definitions=transform_function_definitions,
update=False,
disabled=True,
design_matrix=self,
)

design_matrix_df.columns = pd.MultiIndex.from_product(
Expand Down
3 changes: 3 additions & 0 deletions src/ert/config/gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

from ert.storage import Ensemble

from .design_matrix import DesignMatrix

_logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -73,6 +75,7 @@ class GenKwConfig(ParameterConfig):
transform_function_definitions: List[TransformFunctionDefinition]
forward_init_file: Optional[str] = None
disabled: Optional[bool] = False
design_matrix: Optional[DesignMatrix] = None

def __post_init__(self) -> None:
self.transform_functions: List[TransformFunction] = []
Expand Down
8 changes: 7 additions & 1 deletion src/ert/enkf_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,13 @@ def sample_prior(
config_node = parameter_configs[parameter]
if config_node.forward_init:
continue
if isinstance(config_node, GenKwConfig) and config_node.disabled:
if (
isinstance(config_node, GenKwConfig)
and config_node.design_matrix is not None
):
save_design_matrix_to_ensemble(
config_node.design_matrix, ensemble, active_realizations
)
continue

for realization_nr in active_realizations:
Expand Down
13 changes: 2 additions & 11 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble
from ert.enkf_main import sample_prior
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Experiment, Storage

Expand Down Expand Up @@ -102,16 +102,7 @@ def run_experiment(
np.array(self.active_realizations, dtype=bool),
ensemble=self.ensemble,
)
if (
self.ert_config.analysis_config.design_matrix is not None
and self.ert_config.analysis_config.design_matrix.design_matrix_df
is not None
):
save_design_matrix_to_ensemble(
self.ert_config.analysis_config.design_matrix,
self.ensemble,
np.where(self.active_realizations)[0],
)

sample_prior(
self.ensemble,
np.where(self.active_realizations)[0],
Expand Down
7 changes: 1 addition & 6 deletions src/ert/storage/local_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@
import xtgeo
from pydantic import BaseModel

from ert.config import (
ExtParamConfig,
Field,
GenKwConfig,
SurfaceConfig,
)
from ert.config import ExtParamConfig, Field, GenKwConfig, SurfaceConfig
from ert.config.parsing.context_values import ContextBoolEncoder
from ert.config.response_config import ResponseConfig
from ert.storage.mode import BaseMode, Mode, require_write
Expand Down

0 comments on commit 54f7a2c

Please sign in to comment.