diff --git a/CHANGELOG.md b/CHANGELOG.md index acc596e68..a6b6bdd8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Replaced unmaintained `mordred` dependency by `mordredcommunity` - `SearchSpace`s now use `ndarray` instead of `Tensor` +### Fixed +- `from_simplex` now efficiently validated in `Campaign.validate_config` + ## [0.8.0] - 2024-02-29 ### Changed - BoTorch dependency bumped to `>=0.9.3` diff --git a/baybe/searchspace/core.py b/baybe/searchspace/core.py index f6673a901..c9e56c8e7 100644 --- a/baybe/searchspace/core.py +++ b/baybe/searchspace/core.py @@ -21,7 +21,10 @@ ) from baybe.parameters.base import ContinuousParameter, DiscreteParameter, Parameter from baybe.searchspace.continuous import SubspaceContinuous -from baybe.searchspace.discrete import SubspaceDiscrete +from baybe.searchspace.discrete import ( + SubspaceDiscrete, + validate_simplex_subspace_from_config, +) from baybe.searchspace.validation import validate_parameters from baybe.serialization import SerialMixin, converter, select_constructor_hook from baybe.telemetry import TELEM_LABELS, telemetry_record_value @@ -300,7 +303,7 @@ def transform( def validate_searchspace_from_config(specs: dict, _) -> None: """Validate the search space specifications while skipping costly creation steps.""" - # For product spaces, only validate the inputs + # Validate product inputs without constructing it if specs.get("constructor", None) == "from_product": parameters = converter.structure(specs["parameters"], List[Parameter]) validate_parameters(parameters) @@ -310,9 +313,18 @@ def validate_searchspace_from_config(specs: dict, _) -> None: constraints = converter.structure(specs["constraints"], List[Constraint]) validate_constraints(constraints, parameters) - # For all other types, validate by construction else: - converter.structure(specs, SearchSpace) + discrete_subspace_specs = specs.get("discrete", {}) + if discrete_subspace_specs.get("constructor", None) == "from_simplex": + # Validate discrete simplex subspace + _validation_converter = converter.copy() + _validation_converter.register_structure_hook( + SubspaceDiscrete, validate_simplex_subspace_from_config + ) + _validation_converter.structure(discrete_subspace_specs, SubspaceDiscrete) + else: + # For all other types, validate by construction + converter.structure(specs, SearchSpace) # Register deserialization hook diff --git a/baybe/searchspace/discrete.py b/baybe/searchspace/discrete.py index c3e9be34d..3fd1eb561 100644 --- a/baybe/searchspace/discrete.py +++ b/baybe/searchspace/discrete.py @@ -19,7 +19,7 @@ ) from baybe.parameters.base import DiscreteParameter, Parameter from baybe.parameters.utils import get_parameters_from_dataframe -from baybe.searchspace.validation import validate_parameter_names +from baybe.searchspace.validation import validate_parameter_names, validate_parameters from baybe.serialization import SerialMixin, converter, select_constructor_hook from baybe.utils.boolean import eq_dataframe from baybe.utils.dataframe import ( @@ -197,24 +197,14 @@ def from_product( empty_encoding: bool = False, ) -> SubspaceDiscrete: """See :class:`baybe.searchspace.core.SearchSpace`.""" - # Store the input - if constraints is None: - constraints = [] - else: - # Reorder the constraints according to their execution order - constraints = sorted( - constraints, - key=lambda x: DISCRETE_CONSTRAINTS_FILTERING_ORDER.index(x.__class__), - ) + # Set defaults + constraints = constraints or [] # Create a dataframe representing the experimental search space exp_rep = parameter_cartesian_prod_to_df(parameters) - # Remove entries that violate parameter constraints: - for constraint in (c for c in constraints if c.eval_during_creation): - idxs = constraint.get_invalid(exp_rep) - exp_rep.drop(index=idxs, inplace=True) - exp_rep.reset_index(inplace=True, drop=True) + # Remove entries that violate parameter constraints + _apply_constraint_filter(exp_rep, constraints) return SubspaceDiscrete( parameters=parameters, @@ -354,7 +344,7 @@ def from_simplex( max_values = [max(p.values) for p in simplex_parameters] if not (min(min_values) >= 0.0): raise ValueError( - f"All parameters passed to '{cls.from_simplex.__name__}' " + f"All simplex_parameters passed to '{cls.from_simplex.__name__}' " f"must have non-negative values only." ) @@ -463,10 +453,7 @@ def drop_invalid( exp_rep = pd.merge(exp_rep, product_space, how="cross") # Remove entries that violate parameter constraints: - for constraint in (c for c in constraints if c.eval_during_creation): - idxs = constraint.get_invalid(exp_rep) - exp_rep.drop(index=idxs, inplace=True) - exp_rep.reset_index(inplace=True, drop=True) + _apply_constraint_filter(exp_rep, constraints) return cls( parameters=simplex_parameters + product_parameters, @@ -587,6 +574,27 @@ def transform( return comp_rep +def _apply_constraint_filter(df: pd.DataFrame, constraints: List[DiscreteConstraint]): + """Remove discrete search space entries inplace based on constraints. + + Args: + df: The data in experimental representation to be modified inplace. + constraints: List of discrete constraints. + + """ + # Reorder the constraints according to their execution order + constraints = sorted( + constraints, + key=lambda x: DISCRETE_CONSTRAINTS_FILTERING_ORDER.index(x.__class__), + ) + + # Remove entries that violate parameter constraints: + for constraint in (c for c in constraints if c.eval_during_creation): + idxs = constraint.get_invalid(df) + df.drop(index=idxs, inplace=True) + df.reset_index(inplace=True, drop=True) + + def parameter_cartesian_prod_to_df( parameters: Iterable[Parameter], ) -> pd.DataFrame: @@ -613,5 +621,52 @@ def parameter_cartesian_prod_to_df( return ret +def validate_simplex_subspace_from_config(specs: dict, _) -> None: + """Validate the discrete space while skipping costly creation steps.""" + # Validate product inputs without constructing it + if specs.get("constructor", None) == "from_product": + parameters = converter.structure(specs["parameters"], List[DiscreteParameter]) + validate_parameters(parameters) + + constraints = specs.get("constraints", None) + if constraints: + constraints = converter.structure( + specs["constraints"], List[DiscreteConstraint] + ) + validate_constraints(constraints, parameters) + + # Validate simplex inputs without constructing it + elif specs.get("constructor", None) == "from_simplex": + simplex_parameters = converter.structure( + specs["simplex_parameters"], List[NumericalDiscreteParameter] + ) + + if not all(min(p.values) >= 0.0 for p in simplex_parameters): + raise ValueError( + f"All simplex_parameters passed to " + f"'{SubspaceDiscrete.from_simplex.__name__}' must have non-negative " + f"values only." + ) + + product_parameters = specs.get("product_parameters", None) + if product_parameters: + product_parameters = converter.structure( + specs["product_parameters"], List[DiscreteParameter] + ) + + validate_parameters(simplex_parameters + product_parameters) + + constraints = specs.get("constraints", None) + if constraints: + constraints = converter.structure( + specs["constraints"], List[DiscreteConstraint] + ) + validate_constraints(constraints, simplex_parameters + product_parameters) + + # For all other types, validate by construction + else: + converter.structure(specs, SubspaceDiscrete) + + # Register deserialization hook converter.register_structure_hook(SubspaceDiscrete, select_constructor_hook) diff --git a/tests/conftest.py b/tests/conftest.py index 4d975a84c..c2069dbc4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -655,30 +655,37 @@ def fixture_default_config(): # default campaign object instead of hardcoding it here. This avoids redundant # code and automatically keeps them synced. cfg = """{ - "parameters": [ - { - "type": "NumericalDiscreteParameter", - "name": "Temp_C", - "values": [10, 20, 30, 40] - }, - { - "type": "NumericalDiscreteParameter", - "name": "Concentration", - "values": [0.2, 0.3, 1.4] - }, - __fillin__ + "searchspace": { + "constructor": "from_product", + "parameters": [ + { + "type": "NumericalDiscreteParameter", + "name": "Temp_C", + "values": [10, 20, 30, 40] + }, + { + "type": "NumericalDiscreteParameter", + "name": "Concentration", + "values": [0.2, 0.3, 1.4] + }, + __fillin__ + { + "type": "CategoricalParameter", + "name": "Base", + "values": ["base1", "base2", "base3", "base4", "base5"] + } + ], + "constraints": [] + }, + "objective": { + "mode": "SINGLE", + "targets": [ { - "type": "CategoricalParameter", - "name": "Base", - "values": ["base1", "base2", "base3", "base4", "base5"] + "type": "NumericalTarget", + "name": "Yield", + "mode": "MAX" } - ], - "constraints": [], - "objective": { - "mode": "SINGLE", - "targets": [ - {"name": "Yield", "mode": "MAX"} - ] + ] }, "recommender": { "type": "TwoPhaseMetaRecommender", @@ -716,6 +723,51 @@ def fixture_default_config(): return cfg +@pytest.fixture(name="simplex_config") +def fixture_default_simplex_config(): + """The default simplex config to be used if not specified differently.""" + cfg = """{ + "searchspace": { + "discrete": { + "constructor": "from_simplex", + "simplex_parameters": [ + { + "type": "NumericalDiscreteParameter", + "name": "simplex1", + "values": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + }, + { + "type": "NumericalDiscreteParameter", + "name": "simplex2", + "values": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + } + ], + "product_parameters": [ + { + "type": "CategoricalParameter", + "name": "Granularity", + "values": ["coarse", "medium", "fine"] + } + ], + "max_sum": 1.0, + "boundary_only": true + } + }, + "objective": { + "mode": "SINGLE", + "targets": [ + { + "type": "NumericalTarget", + "name": "Yield", + "mode": "MAX" + } + ] + } + }""" + + return cfg + + @pytest.fixture(name="onnx_str") def fixture_default_onnx_str() -> Union[bytes, None]: """The default ONNX model string to be used if not specified differently.""" diff --git a/tests/serialization/test_campaign_serialization.py b/tests/serialization/test_campaign_serialization.py index 18e68a1b5..105bb37ee 100644 --- a/tests/serialization/test_campaign_serialization.py +++ b/tests/serialization/test_campaign_serialization.py @@ -20,11 +20,21 @@ def test_campaign_serialization(campaign): assert campaign == campaign2 -def test_valid_config(config): +def test_valid_product_config(config): Campaign.validate_config(config) -def test_invalid_config(config): +def test_invalid_product_config(config): config = config.replace("CategoricalParameter", "CatParam") with pytest.raises(ClassValidationError): Campaign.validate_config(config) + + +def test_valid_simplex_config(simplex_config): + Campaign.validate_config(simplex_config) + + +def test_invalid_simplex_config(simplex_config): + simplex_config = simplex_config.replace("0.0, ", "-1.0, 0.0, ") + with pytest.raises(ClassValidationError): + Campaign.validate_config(simplex_config)