diff --git a/CHANGELOG.md b/CHANGELOG.md index 2969cf922..4f84d4c58 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 โ€“ `Campaign.toggle_discrete_candidates` to dynamically in-/exclude discrete candidates - `DiscreteConstraint.get_valid` to conveniently access valid candidates - Functionality for persisting benchmarking results on S3 from a manual pipeline run +- `ContinuousCardinalityConstraint` is now compatible with `BotorchRecommender` +- Warning `MinimumCardinalityViolatedWarning` is triggered when any minimum + cardinality is violated in `BotorchRecommender` +- Attribute `max_n_subspaces` to `BotorchRecommender`, allowing to control + optimization behavior in the presence of multiple subspaces +- Utilities `inactive_parameter_combinations` and`n_inactive_parameter_combinations` + in both `ContinuousCardinalityConstraint`and `SubspaceContinuous` +- Attribute `relative_threshold` and method `get_threshold` to + `ContinuousCardinalityConstraint` +- Utilities `count_zeros` and `is_cardinality_fulfilled` ### Changed - `SubstanceParameter` encodings are now computed exclusively with the diff --git a/README.md b/README.md index 6a339f902..59cde0810 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ Besides functionality to perform a typical recommend-measure loop, BayBE's highl - ๐ŸŽญ Hybrid (mixed continuous and discrete) spaces - ๐Ÿš€ Transfer learning: Mix data from multiple campaigns and accelerate optimization - ๐ŸŽฐ Bandit models: Efficiently find the best among many options in noisy environments (e.g. A/B Testing) +- ๐Ÿ”ข Cardinality constraints: Control the number of active factors in your design - ๐ŸŒŽ Distributed workflows: Run campaigns asynchronously with pending experiments - ๐ŸŽ“ Active learning: Perform smart data acquisition campaigns - โš™๏ธ Custom surrogate models: Enhance your predictions through mechanistic understanding diff --git a/baybe/constraints/continuous.py b/baybe/constraints/continuous.py index d5d5be47e..ce70d00c7 100644 --- a/baybe/constraints/continuous.py +++ b/baybe/constraints/continuous.py @@ -4,11 +4,13 @@ import gc import math -from collections.abc import Collection, Sequence +from collections.abc import Collection, Iterator, Sequence +from itertools import combinations +from math import comb from typing import TYPE_CHECKING, Any import numpy as np -from attr.validators import in_ +from attr.validators import gt, in_, lt from attrs import define, field from baybe.constraints.base import ( @@ -17,6 +19,7 @@ ContinuousNonlinearConstraint, ) from baybe.parameters import NumericalContinuousParameter +from baybe.utils.interval import Interval from baybe.utils.numerical import DTypeFloatNumpy from baybe.utils.validation import finite_float @@ -138,6 +141,31 @@ class ContinuousCardinalityConstraint( ): """Class for continuous cardinality constraints.""" + relative_threshold: float = field( + default=1e-2, converter=float, validator=[gt(0.0), lt(1.0)] + ) + """A relative threshold for determining if the value is considered zero.""" + + @property + def n_inactive_parameter_combinations(self) -> int: + """The number of possible inactive parameter combinations.""" + return sum( + comb(len(self.parameters), n_inactive_parameters) + for n_inactive_parameters in self._inactive_set_sizes() + ) + + def _inactive_set_sizes(self) -> range: + """Get all possible sizes of inactive parameter sets.""" + return range( + len(self.parameters) - self.max_cardinality, + len(self.parameters) - self.min_cardinality + 1, + ) + + def inactive_parameter_combinations(self) -> Iterator[frozenset[str]]: + """Get an iterator over all possible combinations of inactive parameters.""" + for n_inactive_parameters in self._inactive_set_sizes(): + yield from combinations(self.parameters, n_inactive_parameters) + def sample_inactive_parameters(self, batch_size: int = 1) -> list[set[str]]: """Sample sets of inactive parameters according to the cardinality constraints. @@ -176,6 +204,48 @@ def sample_inactive_parameters(self, batch_size: int = 1) -> list[set[str]]: return inactive_params + def get_threshold(self, parameter: NumericalContinuousParameter) -> Interval: + """Get the threshold values of a parameter. + + This method calculates the thresholds based on the parameter's bounds + and the relative threshold. + + Note: + Thresholds (lower, upper) are defined below: + * If lower < 0 and upper > 0, any value v with lower < v < upper are treated + zero; + * If lower = 0 and upper > 0, any value v with lower <= v < upper are + treated zero; + * If lower < 0 and upper = 0, any value v with lower < v <= upper are + treated zero. + + Args: + parameter: The parameter object. + + Returns: + The lower and upper thresholds. + + Raises: + ValueError: when parameter_name is not present in parameter list of this + constraint. + ValueError: when parameter bounds do not cover zero. + """ + if parameter.name not in self.parameters: + raise ValueError( + f"The given parameter with name: {parameter.name} cannot " + f"be found in the parameter list: {self.parameters}." + ) + if parameter.bounds.contains(0.0): + raise ValueError( + f"The bounds of the given parameter must cover zero but its bounds " + f"are ({parameter.bounds.lower}, {parameter.bounds.upper})." + ) + + return Interval( + lower=self.relative_threshold * parameter.bounds.lower, + upper=self.relative_threshold * parameter.bounds.upper, + ) + # Collect leftover original slotted classes processed by `attrs.define` gc.collect() diff --git a/baybe/constraints/validation.py b/baybe/constraints/validation.py index f9c34f9aa..8c80a172f 100644 --- a/baybe/constraints/validation.py +++ b/baybe/constraints/validation.py @@ -8,8 +8,14 @@ from baybe.constraints.discrete import ( DiscreteDependenciesConstraint, ) +from baybe.parameters import NumericalContinuousParameter from baybe.parameters.base import Parameter +try: # For python < 3.11, use the exceptiongroup backport + ExceptionGroup +except NameError: + from exceptiongroup import ExceptionGroup + def validate_constraints( # noqa: DOC101, DOC103 constraints: Collection[Constraint], parameters: Collection[Parameter] @@ -26,6 +32,8 @@ def validate_constraints( # noqa: DOC101, DOC103 ValueError: If any discrete constraint includes a continuous parameter. ValueError: If any discrete constraint that is valid only for numerical discrete parameters includes non-numerical discrete parameters. + ValueError: If any parameter affected by a cardinality constraint does + not include zero. """ if sum(isinstance(itm, DiscreteDependenciesConstraint) for itm in constraints) > 1: raise ValueError( @@ -41,6 +49,9 @@ def validate_constraints( # noqa: DOC101, DOC103 param_names_discrete = [p.name for p in parameters if p.is_discrete] param_names_continuous = [p.name for p in parameters if p.is_continuous] param_names_non_numerical = [p.name for p in parameters if not p.is_numerical] + params_continuous: list[NumericalContinuousParameter] = [ + p for p in parameters if isinstance(p, NumericalContinuousParameter) + ] for constraint in constraints: if not all(p in param_names_all for p in constraint.parameters): @@ -78,6 +89,11 @@ def validate_constraints( # noqa: DOC101, DOC103 f"Parameter list of the affected constraint: {constraint.parameters}." ) + if isinstance(constraint, ContinuousCardinalityConstraint): + validate_cardinality_constraint_parameter_bounds( + constraint, params_continuous + ) + def validate_cardinality_constraints_are_nonoverlapping( constraints: Collection[ContinuousCardinalityConstraint], @@ -98,3 +114,44 @@ def validate_cardinality_constraints_are_nonoverlapping( f"cannot share the same parameters. Found the following overlapping " f"parameter sets: {s1}, {s2}." ) + + +def validate_cardinality_constraint_parameter_bounds( + constraint: ContinuousCardinalityConstraint, + parameters: Collection[NumericalContinuousParameter], +) -> None: + """Validate that all parameters of a continuous cardinality constraint include zero. + + Args: + constraint: A continuous cardinality constraint. + parameters: A collection of parameters, including those affected by the + constraint. + + Raises: + ValueError: If one of the affected parameters does not include zero. + ExceptionGroup: If several of the affected parameters do not include zero. + """ + exceptions = [] + for name in constraint.parameters: + try: + parameter = next(p for p in parameters if p.name == name) + except StopIteration as ex: + raise ValueError( + f"The parameter '{name}' referenced by the constraint is not contained " + f"in the given collection of parameters." + ) from ex + + if not parameter.is_in_range(0.0): + exceptions.append( + ValueError( + f"The bounds of all parameters affected by a constraint of type " + f"'{ContinuousCardinalityConstraint.__name__}' must include zero, " + f"but the bounds of parameter '{name}' are: " + f"{parameter.bounds.to_tuple()}" + ) + ) + + if exceptions: + if len(exceptions) == 1: + raise exceptions[0] + raise ExceptionGroup("Invalid parameter bounds", exceptions) diff --git a/baybe/exceptions.py b/baybe/exceptions.py index 661f61a97..3fd5aaf33 100644 --- a/baybe/exceptions.py +++ b/baybe/exceptions.py @@ -9,6 +9,10 @@ class UnusedObjectWarning(UserWarning): """ +class MinimumCardinalityViolatedWarning(UserWarning): + """Minimum cardinality constraints are violated.""" + + ##### Exceptions ##### diff --git a/baybe/parameters/numerical.py b/baybe/parameters/numerical.py index 5da76060c..e58afa6af 100644 --- a/baybe/parameters/numerical.py +++ b/baybe/parameters/numerical.py @@ -136,7 +136,7 @@ def is_in_range(self, item: float) -> bool: @override @property - def comp_rep_columns(self) -> tuple[str, ...]: + def comp_rep_columns(self) -> tuple[str]: return (self.name,) @override @@ -150,5 +150,38 @@ def summary(self) -> dict: return param_dict +@define(frozen=True, slots=False) +class _FixedNumericalContinuousParameter(ContinuousParameter): + """Parameter class for fixed numerical parameters.""" + + is_numeric: ClassVar[bool] = True + # See base class. + + value: float = field(converter=float) + """The fixed value of the parameter.""" + + @property + def bounds(self) -> Interval: + """The value of the parameter as a degenerate interval.""" + return Interval(self.value, self.value) + + @override + def is_in_range(self, item: float) -> bool: + return item == self.value + + @override + @property + def comp_rep_columns(self) -> tuple[str]: + return (self.name,) + + @override + def summary(self) -> dict: + return dict( + Name=self.name, + Type=self.__class__.__name__, + Value=self.value, + ) + + # Collect leftover original slotted classes processed by `attrs.define` gc.collect() diff --git a/baybe/parameters/utils.py b/baybe/parameters/utils.py index ec33ce455..b6dc18c38 100644 --- a/baybe/parameters/utils.py +++ b/baybe/parameters/utils.py @@ -4,8 +4,13 @@ from typing import Any, TypeVar import pandas as pd +from attrs import evolve from baybe.parameters.base import Parameter +from baybe.parameters.numerical import ( + NumericalContinuousParameter, +) +from baybe.utils.interval import Interval _TParameter = TypeVar("_TParameter", bound=Parameter) @@ -87,3 +92,79 @@ def get_parameters_from_dataframe( def sort_parameters(parameters: Collection[Parameter]) -> tuple[Parameter, ...]: """Sort parameters alphabetically by their names.""" return tuple(sorted(parameters, key=lambda p: p.name)) + + +def activate_parameter( + parameter: NumericalContinuousParameter, + thresholds: Interval, +) -> NumericalContinuousParameter: + """Activates a given parameter by moving its bounds away from zero. + + Important: + Parameters whose ranges include zero but whose bounds do not overlap with the + inactive range (i.e. parameters that contain the value zero far from their + boundary values) remain unchanged, because the corresponding activated parameter + would no longer have a continuous value range. + + Args: + parameter: The parameter to be activated. + thresholds: The thresholds of the inactive range of the parameter. + + Returns: + A copy of the parameter with adjusted bounds. + + Raises: + ValueError: If the threshold does not cover zero. + ValueError: If the parameter cannot be activated since both its bounds are + in the inactive range. + """ + lower_bound = parameter.bounds.lower + upper_bound = parameter.bounds.upper + + if not thresholds.contains(0.0): + raise ValueError( + f"The thresholds must cover zero but ({thresholds.lower}, " + f"{thresholds.upper}) is given." + ) + + if not parameter.bounds.contains(0.0): + raise ValueError( + f"The parameter bounds must cover zero but " + f"({parameter.bounds.lower}, {parameter.bounds.upper}) is " + f"given." + ) + + # Note that the definition on the boundary (lower/upper threshold) is vague. + # The value on the lower/upper boundary is determined as within inactive_range; + # while an activated parameter may take this boundary value (lower/upper + # threshold). We allow the misuse of boundary in the "in_inactive_range" and it + # is just an utils for checking condition. Ultimately, the "key" threshold + # boundary appears as a bound of the activated parameter and this is compatible + # with the thresholds defined in ContinuousCardinalityConstraint, as long as the + # "key" threshold boundary is not zero. The "key" threshold boundary is always + # non-zero when the thresholds are inferred from the bounds of this parameter. + + def in_inactive_range(x: float) -> bool: + """Return true when x is within the inactive range.""" + return thresholds.lower <= x <= thresholds.upper + + # When both bounds in inactive range. + if in_inactive_range(lower_bound) and in_inactive_range(upper_bound): + raise ValueError( + f"Parameter '{parameter.name}' cannot be set active since its " + f"bounds {parameter.bounds.to_tuple()} are entirely contained in the " + f"inactive range ({thresholds.lower}, {thresholds.upper})." + ) + + # When the upper bound is in inactive range, move it to the lower threshold of the + # inactive region. + if lower_bound < thresholds.lower and in_inactive_range(upper_bound): + return evolve(parameter, bounds=(lower_bound, thresholds.lower)) + + # When the lower bound is in inactive range, move it to the upper threshold of + # the inactive region + if upper_bound > thresholds.upper and in_inactive_range(lower_bound): + return evolve(parameter, bounds=(thresholds.upper, upper_bound)) + + # Both bounds separated from inactive range + return parameter diff --git a/baybe/recommenders/pure/bayesian/botorch.py b/baybe/recommenders/pure/bayesian/botorch.py index 3b1c0de96..4f5a61028 100644 --- a/baybe/recommenders/pure/bayesian/botorch.py +++ b/baybe/recommenders/pure/bayesian/botorch.py @@ -1,20 +1,28 @@ """Botorch recommender.""" +from __future__ import annotations + import gc import math -from typing import Any, ClassVar +import warnings +from collections.abc import Collection, Iterable +from typing import TYPE_CHECKING, Any, ClassVar +import numpy as np import pandas as pd from attrs import define, field from attrs.converters import optional as optional_c -from attrs.validators import gt, instance_of +from attrs.validators import ge, gt, instance_of from typing_extensions import override from baybe.acquisition.acqfs import qThompsonSampling +from baybe.constraints import ContinuousCardinalityConstraint from baybe.exceptions import ( IncompatibilityError, IncompatibleAcquisitionFunctionError, + MinimumCardinalityViolatedWarning, ) +from baybe.parameters.numerical import _FixedNumericalContinuousParameter from baybe.recommenders.pure.bayesian.base import BayesianRecommender from baybe.searchspace import ( SearchSpace, @@ -22,6 +30,7 @@ SubspaceContinuous, SubspaceDiscrete, ) +from baybe.utils.cardinality_constraints import is_cardinality_fulfilled from baybe.utils.dataframe import to_tensor from baybe.utils.plotting import to_string from baybe.utils.sampling_algorithms import ( @@ -29,6 +38,9 @@ sample_numerical_df, ) +if TYPE_CHECKING: + from torch import Tensor + @define(kw_only=True) class BotorchRecommender(BayesianRecommender): @@ -76,6 +88,13 @@ class BotorchRecommender(BayesianRecommender): optimization. **Does not affect purely discrete optimization**. """ + max_n_subspaces: int = field(default=10, validator=[instance_of(int), ge(1)]) + """Threshold defining the maximum number of subspaces to consider for exhaustive + search in the presence of cardinality constraints. If the combinatorial number of + groupings into active and inactive parameters dictated by the constraints is greater + than this number, that many randomly selected combinations are selected for + optimization.""" + @sampling_percentage.validator def _validate_percentage( # noqa: DOC101, DOC103 self, _: Any, value: float @@ -90,6 +109,24 @@ def _validate_percentage( # noqa: DOC101, DOC103 f"Hybrid sampling percentage needs to be between 0 and 1 but is {value}" ) + @override + def __str__(self) -> str: + fields = [ + to_string("Surrogate", self._surrogate_model), + to_string( + "Acquisition function", self.acquisition_function, single_line=True + ), + to_string("Compatibility", self.compatibility, single_line=True), + to_string( + "Sequential continuous", self.sequential_continuous, single_line=True + ), + to_string("Hybrid sampler", self.hybrid_sampler, single_line=True), + to_string( + "Sampling percentage", self.sampling_percentage, single_line=True + ), + ] + return to_string(self.__class__.__name__, *fields) + @override def _recommend_discrete( self, @@ -168,38 +205,159 @@ def _recommend_continuous( Returns: A dataframe containing the recommendations as individual rows. """ - # For batch size > 1, this optimizer needs a MC acquisition function if batch_size > 1 and not self.acquisition_function.is_mc: raise IncompatibleAcquisitionFunctionError( f"The '{self.__class__.__name__}' only works with Monte Carlo " f"acquisition functions for batch sizes > 1." ) + points, _ = self._recommend_continuous_torch(subspace_continuous, batch_size) + + return pd.DataFrame(points, columns=subspace_continuous.parameter_names) + + def _recommend_continuous_torch( + self, subspace_continuous: SubspaceContinuous, batch_size: int + ) -> tuple[Tensor, Tensor]: + """Dispatcher selecting continuous optimization routine.""" + if subspace_continuous.constraints_cardinality: + return self._recommend_continuous_with_cardinality_constraints( + subspace_continuous, batch_size + ) + else: + return self._recommend_continuous_without_cardinality_constraints( + subspace_continuous, batch_size + ) + + def _recommend_continuous_with_cardinality_constraints( + self, + subspace_continuous: SubspaceContinuous, + batch_size: int, + ) -> tuple[Tensor, Tensor]: + """Recommend from a continuous search space with cardinality constraints. + + This is achieved by consideringย the individual restricted subspaces that can be + obtained by splitting the parameters into sets of active and inactive + parameters, according to what is allowed by the cardinality constraints. In each + of these spaces, the in-/activity assignment is fixed, so that the cardinality + constraints can be removed and a regular optimization can be performed. The + recommendation is then constructed from the combined optimization results of the + unconstrained spaces. + + Args: + subspace_continuous: The continuous subspace from which to generate + recommendations. + batch_size: The size of the recommendation batch. + + Returns: + The recommendations and corresponding acquisition values. + + Raises: + ValueError: If the continuous search space has no cardinality constraints. + """ + if not subspace_continuous.constraints_cardinality: + raise ValueError( + f"'{self._recommend_continuous_with_cardinality_constraints.__name__}' " + f"expects a subspace with constraints of type " + f"'{ContinuousCardinalityConstraint.__name__}'. " + ) + + # Determine search scope based on number of inactive parameter combinations + exhaustive_search = ( + subspace_continuous.n_inactive_parameter_combinations + <= self.max_n_subspaces + ) + iterator: Iterable[Collection[str]] + if exhaustive_search: + # If manageable, evaluate all combinations of inactive parameters + iterator = subspace_continuous.inactive_parameter_combinations() + else: + # Otherwise, draw a random subset of inactive parameter combinations + iterator = subspace_continuous._sample_inactive_parameters( + self.max_n_subspaces + ) + + # Create iterable of subspaces to be optimized + subspaces = ( + ( + subspace_continuous._enforce_cardinality_constraints_via_assignment( + inactive_parameters + ) + ) + for inactive_parameters in iterator + ) + + points, acqf_value = self._optimize_continuous_subspaces(subspaces, batch_size) + + # Check if any minimum cardinality constraints are violated + if not is_cardinality_fulfilled( + subspace_continuous, + pd.DataFrame(points, columns=subspace_continuous.parameter_names), + "min", + ): + warnings.warn( + "At least one minimum cardinality constraint is violated.", + MinimumCardinalityViolatedWarning, + ) + + return points, acqf_value + + def _recommend_continuous_without_cardinality_constraints( + self, + subspace_continuous: SubspaceContinuous, + batch_size: int, + ) -> tuple[Tensor, Tensor]: + """Recommend from a continuous search space without cardinality constraints. + + Args: + subspace_continuous: The continuous subspace from which to generate + recommendations. + batch_size: The size of the recommendation batch. + + Returns: + The recommendations and corresponding acquisition values. + + Raises: + ValueError: If the continuous search space has cardinality constraints. + """ import torch from botorch.optim import optimize_acqf - points, _ = optimize_acqf( + if subspace_continuous.constraints_cardinality: + raise ValueError( + f"'{self._recommend_continuous_without_cardinality_constraints.__name__}' " # noqa: E501 + f"expects a subspace without constraints of type " + f"'{ContinuousCardinalityConstraint.__name__}'. " + ) + + fixed_parameters = { + idx: p.value + for (idx, p) in enumerate(subspace_continuous.parameters) + if isinstance(p, _FixedNumericalContinuousParameter) + } + + points, acqf_values = optimize_acqf( acq_function=self._botorch_acqf, bounds=torch.from_numpy(subspace_continuous.comp_rep_bounds.values), q=batch_size, num_restarts=self.n_restarts, raw_samples=self.n_raw_samples, + # TODO: https://github.com/pytorch/botorch/issues/2042 + fixed_features=fixed_parameters or None, + # TODO: https://github.com/pytorch/botorch/issues/2042 equality_constraints=[ c.to_botorch(subspace_continuous.parameters) for c in subspace_continuous.constraints_lin_eq ] - or None, # TODO: https://github.com/pytorch/botorch/issues/2042 + or None, + # TODO: https://github.com/pytorch/botorch/issues/2042 inequality_constraints=[ c.to_botorch(subspace_continuous.parameters) for c in subspace_continuous.constraints_lin_ineq ] - or None, # TODO: https://github.com/pytorch/botorch/issues/2042 + or None, sequential=self.sequential_continuous, ) - - # Return optimized points as dataframe - rec = pd.DataFrame(points, columns=subspace_continuous.parameter_names) - return rec + return points, acqf_values @override def _recommend_hybrid( @@ -314,23 +472,47 @@ def _recommend_hybrid( return rec_exp - @override - def __str__(self) -> str: - fields = [ - to_string("Surrogate", self._surrogate_model), - to_string( - "Acquisition function", self.acquisition_function, single_line=True - ), - to_string("Compatibility", self.compatibility, single_line=True), - to_string( - "Sequential continuous", self.sequential_continuous, single_line=True - ), - to_string("Hybrid sampler", self.hybrid_sampler, single_line=True), - to_string( - "Sampling percentage", self.sampling_percentage, single_line=True - ), - ] - return to_string(self.__class__.__name__, *fields) + def _optimize_continuous_subspaces( + self, subspaces: Iterable[SubspaceContinuous], batch_size: int + ) -> tuple[Tensor, Tensor]: + """Find the optimum candidates from multiple continuous subspaces. + + **Important**: A subspace without a feasible solution will be ignored + silently, and no warning will be raised. This design is intentional to + accommodate recommendations with cardinality constraints. Please be mindful + of this behavior when invoking this method. + + Args: + subspaces: The subspaces to consider for the optimization. + batch_size: The number of points to be recommended. + + Returns: + The batch of candidates and the corresponding acquisition value. + """ + acqf_values_all: list[Tensor] = [] + points_all: list[Tensor] = [] + + for subspace in subspaces: + try: + # Optimize the acquisition function + p, acqf = self._recommend_continuous_torch(subspace, batch_size) + + # Append optimization results + points_all.append(p) + acqf_values_all.append(acqf) + + # TODO: Replace ValueError with customized erorr. See + # https://github.com/pytorch/botorch/pull/2652 + # The optimization problem may be infeasible in certain subspaces + except ValueError: + pass + + # Find the best option f + best_idx = np.argmax(acqf_values_all) + points = points_all[best_idx] + acqf_value = acqf_values_all[best_idx] + + return points, acqf_value # Collect leftover original slotted classes processed by `attrs.define` diff --git a/baybe/searchspace/continuous.py b/baybe/searchspace/continuous.py index 5fedf3f4e..70493e7e4 100644 --- a/baybe/searchspace/continuous.py +++ b/baybe/searchspace/continuous.py @@ -3,14 +3,15 @@ from __future__ import annotations import gc +import math import warnings -from collections.abc import Collection, Sequence -from itertools import chain +from collections.abc import Collection, Iterable, Sequence +from itertools import chain, product from typing import TYPE_CHECKING, Any, cast import numpy as np import pandas as pd -from attrs import define, field, fields +from attrs import define, evolve, field, fields from typing_extensions import override from baybe.constraints import ( @@ -19,11 +20,17 @@ ) from baybe.constraints.base import ContinuousConstraint, ContinuousNonlinearConstraint from baybe.constraints.validation import ( + validate_cardinality_constraint_parameter_bounds, validate_cardinality_constraints_are_nonoverlapping, ) from baybe.parameters import NumericalContinuousParameter from baybe.parameters.base import ContinuousParameter -from baybe.parameters.utils import get_parameters_from_dataframe, sort_parameters +from baybe.parameters.numerical import _FixedNumericalContinuousParameter +from baybe.parameters.utils import ( + activate_parameter, + get_parameters_from_dataframe, + sort_parameters, +) from baybe.searchspace.validation import ( validate_parameter_names, ) @@ -134,6 +141,23 @@ def _validate_constraints_lin_ineq( f"the 'operator' for all list items should be '>=' or '<='." ) + @property + def n_inactive_parameter_combinations(self) -> int: + """The number of possible inactive parameter combinations.""" + return math.prod( + c.n_inactive_parameter_combinations for c in self.constraints_cardinality + ) + + def inactive_parameter_combinations(self) -> Iterable[frozenset[str]]: + """Get an iterator over all possible combinations of inactive parameters.""" + for combination in product( + *[ + con.inactive_parameter_combinations() + for con in self.constraints_cardinality + ] + ): + yield frozenset(chain(*combination)) + @constraints_nonlin.validator def _validate_constraints_nonlin(self, _, __) -> None: """Validate nonlinear constraints.""" @@ -142,6 +166,9 @@ def _validate_constraints_nonlin(self, _, __) -> None: self.constraints_cardinality ) + for con in self.constraints_cardinality: + validate_cardinality_constraint_parameter_bounds(con, self.parameters) + def to_searchspace(self) -> SearchSpace: """Turn the subspace into a search space with no discrete part.""" from baybe.searchspace.core import SearchSpace @@ -277,6 +304,12 @@ def comp_rep_columns(self) -> tuple[str, ...]: """The columns spanning the computational representation.""" return tuple(chain.from_iterable(p.comp_rep_columns for p in self.parameters)) + @property + def parameter_names_in_cardinality_constraints(self) -> tuple[str, ...]: + """The names of all parameters affected by cardinality constraints.""" + names_per_constraint = (c.parameters for c in self.constraints_cardinality) + return tuple(chain(*names_per_constraint)) + @property def comp_rep_bounds(self) -> pd.DataFrame: """The minimum and maximum values of the computational representation.""" @@ -309,6 +342,49 @@ def _drop_parameters(self, parameter_names: Collection[str]) -> SubspaceContinuo ], ) + def _enforce_cardinality_constraints_via_assignment( + self, + inactive_parameter_names: Collection[str], + ) -> SubspaceContinuous: + """Create a copy of the subspace with fixed inactive parameters. + + The returned subspace requires no cardinality constraints since โ€“ for the + given separation of parameter into active an inactive sets โ€“ the + cardinality constraints are implemented by fixing the inactive parameters to + zero and bounding the active parameters away from zero. + + Args: + inactive_parameter_names: The names of the parameter to be inactivated. + + Returns: + A new subspace with fixed inactive parameters and no cardinality + constraints. + """ + # Extract active parameters involved in cardinality constraints + active_parameter_names = set( + self.parameter_names_in_cardinality_constraints + ).difference(inactive_parameter_names) + + # Adjust parameters depending on their in-/activity assignment + adjusted_parameters: list[ContinuousParameter] = [] + p_adjusted: ContinuousParameter + for p in self.parameters: + if p.name in inactive_parameter_names: + p_adjusted = _FixedNumericalContinuousParameter(name=p.name, value=0.0) + elif p.name in active_parameter_names: + # cardinality constraint object containing the current parameter + cardinality_constraint_with_p = [ + c for c in self.constraints_cardinality if p.name in c.parameters + ][0] + p_adjusted = activate_parameter( + p, cardinality_constraint_with_p.get_threshold(p) + ) + else: + p_adjusted = p + adjusted_parameters.append(p_adjusted) + + return evolve(self, parameters=adjusted_parameters, constraints_nonlin=()) + def transform( self, df: pd.DataFrame | None = None, @@ -453,11 +529,20 @@ def _sample_from_polytope_with_cardinality_constraints( # Randomly set some parameters inactive inactive_params_sample = self._sample_inactive_parameters(1)[0] - # Remove the inactive parameters from the search space - subspace_without_cardinality_constraint = self._drop_parameters( - inactive_params_sample + # Remove the inactive parameters from the search space. In the first + # step, the active parameters get activated and inactive parameters are + # fixed to zero. The first step helps ensure active parameters stay + # non-zero, especially when one boundary is zero. The second step is + # optional and it helps reduce the parameter space with certain + # computational cost. + subspace_without_cardinality_constraint = ( + self._enforce_cardinality_constraints_via_assignment( + inactive_params_sample + )._drop_parameters(inactive_params_sample) ) + # TODO: Replace ValueError with customized erorr. See + # https://github.com/pytorch/botorch/pull/2652 # Sample from the reduced space try: sample = subspace_without_cardinality_constraint.sample_uniform(1) diff --git a/baybe/utils/cardinality_constraints.py b/baybe/utils/cardinality_constraints.py new file mode 100644 index 000000000..f93d0ec5c --- /dev/null +++ b/baybe/utils/cardinality_constraints.py @@ -0,0 +1,113 @@ +"""Utilities related to cardinality constraints.""" + +from typing import Literal + +import numpy as np +import pandas as pd + +from baybe.searchspace import SubspaceContinuous +from baybe.utils.interval import Interval + + +def count_zeros(thresholds: tuple[Interval, ...], points: pd.DataFrame) -> np.ndarray: + """Return the counts of zeros in the recommendations. + + Args: + thresholds: A list of thresholds according to which the counts of zeros + in the recommendations should be calculated. + points: The recommendations of the parameter objects. + + Returns: + The counts of zero parameters in the recommendations. + + Raises: + ValueError: If the number of thresholds differs from the number of + parameters in points. + """ + if len(thresholds) != len(points.columns): + raise ValueError( + f"The size of thresholds ({len(thresholds)}) must be the same as the " + f"number of parameters ({len(points.columns)}) in points." + ) + # Get the lower/upper thresholds for determining zeros/non-zeros + lower_thresholds = np.array([threshold.lower for threshold in thresholds]) + lower_thresholds = np.broadcast_to(lower_thresholds, points.shape) + + upper_thresholds = np.array([threshold.upper for threshold in thresholds]) + upper_thresholds = np.broadcast_to(upper_thresholds, points.shape) + + # Boolean values indicating whether the candidates are treated zeros: True for zero + zero_flags = (points > lower_thresholds) & (points < upper_thresholds) + + # Correct the comparison on the special boundary: zero. This step is needed + # because when the lower_threshold = 0, a value v with lower_threshold <= v < + # upper_threshold should be treated zero. + zero_flags = (points == 0.0) | zero_flags + + return np.sum(zero_flags, axis=1) + + +def is_cardinality_fulfilled( + subspace_continuous: SubspaceContinuous, + batch: pd.DataFrame, + type_cardinality: Literal["min", "max"], +) -> bool: + """Check whether all minimum (or maximum) cardinality constraints are fulfilled. + + Args: + subspace_continuous: The continuous subspace from which candidates are + generated. + batch: The recommended batch + type_cardinality: "min" or "max". "min" indicates all minimum cardinality + constraints will be checked; "max" for all maximum cardinality constraints. + + Returns: + Return "True" if all minimum (or maximum) cardinality constraints are + fulfilled; "False" otherwise. + + Raises: + ValueError: If type_cardinality is neither "min" nor "max". + """ + if type_cardinality not in ["min", "max"]: + raise ValueError( + f"Unknown type of cardinality. Only support min or max but " + f"{type_cardinality=} is given." + ) + + if len(subspace_continuous.constraints_cardinality) == 0: + return True + + for c in subspace_continuous.constraints_cardinality: + # No need to check the redundant cardinality constraints that are + # - min_cardinality = 0 + # - max_cardinality = len(parameters) + if (c.min_cardinality == 0) and type_cardinality == "min": + continue + + if (c.max_cardinality == len(c.parameters)) and type_cardinality == "max": + continue + + # Batch of parameters that are related to cardinality constraint + batch_related_to_c = batch[c.parameters] + + # Parameters related to cardinality constraint + parameters_in_c = subspace_continuous.get_parameters_by_name(c.parameters) + + # Thresholds of parameters that are related to the cardinality constraint + thresholds = tuple(c.get_threshold(p) for p in parameters_in_c) + + # Count the number of zeros + n_zeros = count_zeros(thresholds, batch_related_to_c) + + # When any minimum cardinality is violated + if type_cardinality == "min" and np.any( + len(c.parameters) - n_zeros < c.min_cardinality + ): + return False + + # When any maximum cardinality is violated + if type_cardinality == "max" and np.any( + len(c.parameters) - n_zeros > c.max_cardinality + ): + return False + return True diff --git a/tests/constraints/test_cardinality_constraint_continuous.py b/tests/constraints/test_cardinality_constraint_continuous.py index 69aa71084..5c4780109 100644 --- a/tests/constraints/test_cardinality_constraint_continuous.py +++ b/tests/constraints/test_cardinality_constraint_continuous.py @@ -1,6 +1,8 @@ """Tests for the continuous cardinality constraint.""" +import warnings from itertools import combinations_with_replacement +from warnings import WarningMessage import numpy as np import pandas as pd @@ -10,36 +12,60 @@ ContinuousCardinalityConstraint, ContinuousLinearConstraint, ) -from baybe.parameters import NumericalContinuousParameter -from baybe.recommenders.pure.nonpredictive.sampling import RandomRecommender +from baybe.exceptions import MinimumCardinalityViolatedWarning +from baybe.parameters.numerical import NumericalContinuousParameter +from baybe.recommenders import BotorchRecommender from baybe.searchspace.core import SearchSpace, SubspaceContinuous +from baybe.targets.numerical import NumericalTarget +from baybe.utils.cardinality_constraints import is_cardinality_fulfilled -def _validate_samples( - samples: pd.DataFrame, max_cardinality: int, min_cardinality: int, batch_size: int +def _validate_cardinality_constrained_batch( + subspace_continuous: SubspaceContinuous, + batch: pd.DataFrame, + batch_size: int, + captured_warnings: list[WarningMessage | None], ): - """Validate if cardinality-constrained samples fulfill the necessary conditions. - - Conditions to check: - * Cardinality is in requested range - * The batch contains right number of samples - * The samples are free of duplicates (except all zeros) + """Validate that a cardinality-constrained batch fulfills the necessary conditions. Args: - samples: Samples to check - max_cardinality: Maximum allowed cardinality - min_cardinality: Minimum required cardinality - batch_size: Requested batch size + subspace_continuous: The continuous subspace from which to recommend the points. + batch: Batch to validate. + batch_size: The number of points to be recommended. + captured_warnings: A list of captured warnings. """ - # Assert that cardinality constraint is fulfilled - n_nonzero = np.sum(~np.isclose(samples, 0.0), axis=1) - assert np.all(n_nonzero >= min_cardinality) and np.all(n_nonzero <= max_cardinality) + # Assert that the maximum cardinality constraint is fulfilled + assert is_cardinality_fulfilled(subspace_continuous, batch, "max") - # Assert that we obtain as many samples as requested - assert len(samples) == batch_size + # Check whether the minimum cardinality constraint is fulfilled + is_min_cardinality_fulfilled = is_cardinality_fulfilled( + subspace_continuous, batch, "min" + ) - # If there are duplicates, they must all come from the case cardinality = 0 - assert np.all(samples[samples.duplicated()] == 0.0) + # A warning must be raised when the minimum cardinality constraint is not fulfilled + if not is_min_cardinality_fulfilled: + w_message = "Minimum cardinality constraints are not guaranteed." + assert any(str(w.message) == w_message for w in captured_warnings) + + # Assert that we obtain as many samples as requested + assert batch.shape[0] == batch_size + + # Sanity check: If all recommendations in the batch are identical, something is + # fishy โ€“ unless the cardinality is 0, in which case the entire batch must contain + # zeros. Technically, the probability of getting such a degenerate batch + # is not zero, hence this is not a strict requirement. However, in earlier BoTorch + # versions, this simply happened due to a bug in their sampler: + # https://github.com/pytorch/botorch/issues/2351 + # We thus include this check as a safety net for catching regressions. If it + # turns out the check fails because we observe degenerate batches as actual + # recommendations, we need to invent something smarter. + max_cardinalities = [ + c.max_cardinality for c in subspace_continuous.constraints_cardinality + ] + if len(unique_row := batch.drop_duplicates()) == 1: + assert (unique_row.iloc[0] == 0.0).all() and all( + max_cardinality == 0 for max_cardinality in max_cardinalities + ) # Combinations of cardinalities to be tested @@ -72,11 +98,15 @@ def test_sampling_cardinality_constraint(cardinality_bounds: tuple[int, int]): ), ) - subspace = SubspaceContinuous(parameters=parameters, constraints_nonlin=constraints) - samples = subspace.sample_uniform(BATCH_SIZE) + subspace_continous = SubspaceContinuous( + parameters=parameters, constraints_nonlin=constraints + ) + + with warnings.catch_warnings(record=True) as w: + samples = subspace_continous.sample_uniform(BATCH_SIZE) - # Assert that conditions listed in_validate_samples() are fulfilled - _validate_samples(samples, max_cardinality, min_cardinality, BATCH_SIZE) + # Assert that the constraint conditions hold + _validate_cardinality_constrained_batch(subspace_continous, samples, BATCH_SIZE, w) def test_polytope_sampling_with_cardinality_constraint(): @@ -119,13 +149,17 @@ def test_polytope_sampling_with_cardinality_constraint(): min_cardinality=MIN_CARDINALITY, ), ] - searchspace = SearchSpace.from_product(parameters, constraints) + subspace_continous = SubspaceContinuous.from_product(parameters, constraints) - samples = searchspace.continuous.sample_uniform(BATCH_SIZE) + with warnings.catch_warnings(record=True) as w: + samples = subspace_continous.sample_uniform(BATCH_SIZE) - # Assert that conditions listed in_validate_samples() are fulfilled - _validate_samples( - samples[params_cardinality], MAX_CARDINALITY, MIN_CARDINALITY, BATCH_SIZE + # Assert that the constraint conditions hold + _validate_cardinality_constrained_batch( + subspace_continous, + samples, + BATCH_SIZE, + w, ) # Assert that linear equality constraint is fulfilled @@ -143,32 +177,64 @@ def test_polytope_sampling_with_cardinality_constraint(): ) -@pytest.mark.parametrize( - "parameter_names", [["Conti_finite1", "Conti_finite2", "Conti_finite3"]] -) -@pytest.mark.parametrize("constraint_names", [["ContiConstraint_5"]]) -@pytest.mark.parametrize("batch_size", [5], ids=["b5"]) -def test_random_recommender_with_cardinality_constraint( - parameters: list[NumericalContinuousParameter], - constraints: list[ContinuousCardinalityConstraint], - batch_size: int, -): - """Recommendations generated by a `RandomRecommender` under a cardinality constraint - have the expected number of nonzero elements.""" # noqa +def test_min_cardinality_warning(): + """Providing candidates violating minimum cardinality constraint raises a + warning. + """ # noqa + N_PARAMETERS = 2 + MIN_CARDINALITY = 2 + MAX_CARDINALITY = 2 + BATCH_SIZE = 20 - searchspace = SearchSpace.from_product( - parameters=parameters, constraints=constraints - ) - recommender = RandomRecommender() - recommendations = recommender.recommend( - searchspace=searchspace, - batch_size=batch_size, - ) + lower_bound = -0.5 + upper_bound = 0.5 + stepsize = 0.05 + parameters = [ + NumericalContinuousParameter(name=f"x_{i+1}", bounds=(lower_bound, upper_bound)) + for i in range(N_PARAMETERS) + ] - # Assert that conditions listed in_validate_samples() are fulfilled - _validate_samples( - recommendations, max_cardinality=2, min_cardinality=1, batch_size=batch_size - ) + constraints = [ + ContinuousCardinalityConstraint( + parameters=[p.name for p in parameters], + max_cardinality=MAX_CARDINALITY, + min_cardinality=MIN_CARDINALITY, + ), + ] + + searchspace = SearchSpace.from_product(parameters, constraints) + objective = NumericalTarget("t", "MAX").to_objective() + + # Create a scenario in which + # - The optimum of the target function is at the origin + # - The Botorch recommender is likely to provide candidates at the origin, + # which violates the minimum cardinality constraint. + def custom_target(x1: np.ndarray, x2: np.ndarray) -> np.ndarray: + """A custom target function with maximum at the origin.""" + return -abs(x1) - abs(x2) + + def prepare_measurements() -> pd.DataFrame: + """Prepare measurements.""" + x1 = np.arange(lower_bound, upper_bound + stepsize, stepsize) + # Exclude 0 from the array + X1, X2 = np.meshgrid(x1[abs(x1) > stepsize / 2], x1[abs(x1) > stepsize / 2]) + + return pd.DataFrame( + { + "x_1": X1.flatten(), + "x_2": X2.flatten(), + "t": custom_target(X1.flatten(), X2.flatten()), + } + ) + + with warnings.catch_warnings(record=True) as captured_warnings: + BotorchRecommender().recommend( + BATCH_SIZE, searchspace, objective, prepare_measurements() + ) + assert any( + issubclass(w.category, MinimumCardinalityViolatedWarning) + for w in captured_warnings + ) def test_empty_constraints_after_cardinality_constraint(): diff --git a/tests/constraints/test_constraints_continuous.py b/tests/constraints/test_constraints_continuous.py index 46adc4b21..182f18da0 100644 --- a/tests/constraints/test_constraints_continuous.py +++ b/tests/constraints/test_constraints_continuous.py @@ -1,11 +1,24 @@ """Test for imposing continuous constraints.""" +import warnings + import numpy as np +import pandas as pd import pytest from pytest import param from baybe.constraints import ContinuousLinearConstraint +from baybe.constraints.continuous import ContinuousCardinalityConstraint +from baybe.parameters.numerical import NumericalContinuousParameter +from baybe.recommenders.pure.bayesian.base import BayesianRecommender +from baybe.recommenders.pure.bayesian.botorch import BotorchRecommender +from baybe.recommenders.pure.nonpredictive.sampling import RandomRecommender +from baybe.searchspace import SearchSpace +from baybe.targets.numerical import NumericalTarget from tests.conftest import run_iterations +from tests.constraints.test_cardinality_constraint_continuous import ( + _validate_cardinality_constrained_batch, +) @pytest.mark.parametrize("parameter_names", [["Conti_finite1", "Conti_finite2"]]) @@ -68,6 +81,43 @@ def test_inequality3(campaign, n_iterations, batch_size): assert (1.0 * res["Conti_finite1"] + 3.0 * res["Conti_finite2"]).le(0.301).all() +@pytest.mark.parametrize("recommender", [RandomRecommender(), BotorchRecommender()]) +def test_cardinality_constraint(recommender): + """Cardinality constraints are taken into account by the recommender.""" + MIN_CARDINALITY = 4 + MAX_CARDINALITY = 7 + BATCH_SIZE = 10 + + parameters = [NumericalContinuousParameter(str(i), (0, 1)) for i in range(10)] + constraints = [ + ContinuousCardinalityConstraint( + [p.name for p in parameters], MIN_CARDINALITY, MAX_CARDINALITY + ) + ] + searchspace = SearchSpace.from_product(parameters, constraints) + + if isinstance(recommender, BayesianRecommender): + objective = NumericalTarget("t", "MAX").to_objective() + measurements = pd.DataFrame(searchspace.continuous.sample_uniform(2)) + measurements["t"] = np.random.random(len(measurements)) + else: + objective = None + measurements = None + + with warnings.catch_warnings(record=True) as w: + recommendation = recommender.recommend( + BATCH_SIZE, searchspace, objective, measurements + ) + + # Assert that the constraint conditions hold + _validate_cardinality_constrained_batch( + searchspace.continuous, + recommendation, + BATCH_SIZE, + w, + ) + + @pytest.mark.slow @pytest.mark.parametrize( "parameter_names", diff --git a/tests/test_searchspace.py b/tests/test_searchspace.py index d127e0698..68bb5dd29 100644 --- a/tests/test_searchspace.py +++ b/tests/test_searchspace.py @@ -268,3 +268,22 @@ def test_cardinality_constraints_with_overlapping_parameters(): ), ), ) + + +def test_cardinality_constraint_with_invalid_parameter_bounds(): + """Imposing a cardinality constraint on a parameter whose range does not include + zero raises an error.""" # noqa + parameters = ( + NumericalContinuousParameter("c1", (0, 1)), + NumericalContinuousParameter("c2", (1, 2)), + ) + with pytest.raises(ValueError, match="must include zero"): + SubspaceContinuous( + parameters=parameters, + constraints_nonlin=( + ContinuousCardinalityConstraint( + parameters=["c1", "c2"], + max_cardinality=1, + ), + ), + ) diff --git a/tests/utils/test_parameters.py b/tests/utils/test_parameters.py new file mode 100644 index 000000000..2c4a3af8c --- /dev/null +++ b/tests/utils/test_parameters.py @@ -0,0 +1,153 @@ +"""Tests for parameter utilities.""" + +import pytest +from pytest import param + +from baybe.parameters import NumericalContinuousParameter +from baybe.parameters.numerical import _FixedNumericalContinuousParameter +from baybe.parameters.utils import activate_parameter +from baybe.utils.interval import Interval + + +def mirror_interval(interval: Interval) -> Interval: + """Return an interval copy mirrored around the origin.""" + return Interval(lower=-interval.upper, upper=-interval.lower) + + +@pytest.mark.parametrize( + ( + "bounds", + "thresholds", + "is_valid", + "expected_bounds", + ), + [ + param( + Interval(lower=-1.0, upper=1.0), + Interval(lower=-1.0, upper=1.0), + False, + None, + id="bounds_on_thresholds", + ), + param( + Interval(lower=-1.0, upper=1.0), + Interval(lower=-1.5, upper=1.5), + False, + None, + id="bounds_in_thresholds", + ), + param( + Interval(lower=-1.0, upper=1.0), + Interval(lower=-1.5, upper=1.0), + False, + None, + id="bounds_in_thresholds_single_side_match", + ), + param( + Interval(lower=-1.0, upper=1.0), + Interval(lower=-0.5, upper=0.5), + True, + Interval(lower=-1.0, upper=1.0), + id="thresholds_in_bounds", + ), + param( + Interval(lower=-1.0, upper=1.0), + Interval(lower=-0.5, upper=1.0), + True, + Interval(lower=-1.0, upper=-0.5), + id="thresholds_in_bounds_single_side_match", + ), + param( + Interval(lower=-0.5, upper=1.0), + Interval(lower=-1.0, upper=0.5), + True, + Interval(lower=0.5, upper=1.0), + id="bounds_intersected_with_thresholds", + ), + param( + Interval(lower=0.0, upper=1.0), + Interval(lower=-1.0, upper=0.0), + True, + Interval(lower=0.0, upper=1.0), + id="bounds_intersected_with_thresholds_on_one_point", + ), + ], +) +@pytest.mark.parametrize("mirror", [False, True]) +def test_activate_parameter( + bounds: Interval, + thresholds: Interval, + is_valid: bool, + expected_bounds: Interval | None, + mirror: bool, +) -> None: + """Test that the utility correctly activate a parameter. + + Args: + bounds: the bounds of the parameter to activate + thresholds: the thresholds of inactive range + is_valid: boolean variable indicating whether a parameter is returned from + activate_parameter + expected_bounds: the bounds of the activated parameter if one is returned + mirror: if true both bounds and thresholds get mirrored + + Returns: + None + """ + if mirror: + bounds = mirror_interval(bounds) + thresholds = mirror_interval(thresholds) + if mirror and is_valid: + expected_bounds = mirror_interval(expected_bounds) + + parameter = NumericalContinuousParameter("parameter", bounds=bounds) + + if is_valid: + activated_parameter = activate_parameter(parameter, thresholds) + assert activated_parameter.bounds == expected_bounds + if expected_bounds.is_degenerate: + assert isinstance(activated_parameter, _FixedNumericalContinuousParameter) + else: + with pytest.raises(ValueError, match="cannot be set active"): + activate_parameter(parameter, thresholds) + + +@pytest.mark.parametrize( + ("bounds", "thresholds", "match"), + [ + param( + Interval(lower=-0.5, upper=0.5), + Interval(lower=0.5, upper=1.0), + "The thresholds must cover zero", + id="invalid_thresholds", + ), + param( + Interval(lower=0.5, upper=1.0), + Interval(lower=-0.5, upper=0.5), + "The parameter bounds must cover zero", + id="invalid_bounds", + ), + ], +) +@pytest.mark.parametrize("mirror", [False, True]) +def test_invalid_activate_parameter( + bounds: Interval, thresholds: Interval, match: str, mirror: bool +) -> None: + """Test that invalid bounds or thresholds are given. + + Args: + bounds: the bounds of the parameter to activate + thresholds: the thresholds of inactive range + match: error message to match + mirror: if true both bounds and thresholds get mirrored + + Returns: + None + """ + if mirror: + bounds = mirror_interval(bounds) + thresholds = mirror_interval(thresholds) + + parameter = NumericalContinuousParameter("parameter", bounds=bounds) + with pytest.raises(ValueError, match=match): + activate_parameter(parameter, thresholds)