From 0b27e02e14b7ffd5071b2ed435d9f8d4026cec49 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 11 Oct 2024 11:19:20 +0200 Subject: [PATCH 01/13] Rework transform signature --- baybe/objectives/base.py | 22 ++++++++++++++++++---- baybe/objectives/desirability.py | 13 ++++++++++--- baybe/objectives/single.py | 11 +++++++++-- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/baybe/objectives/base.py b/baybe/objectives/base.py index 981e0d006..e155db26d 100644 --- a/baybe/objectives/base.py +++ b/baybe/objectives/base.py @@ -29,15 +29,29 @@ def targets(self) -> tuple[Target, ...]: """The targets included in the objective.""" @abstractmethod - def transform(self, data: pd.DataFrame) -> pd.DataFrame: + def transform( + self, + df: pd.DataFrame, + /, + *, + allow_missing: bool = False, + allow_extra: bool = False, + ) -> pd.DataFrame: """Transform target values from experimental to computational representation. Args: - data: The data to be transformed. Must contain columns for all targets - but can contain additional columns. + df: The dataframe to be transformed. The allowed columns of the dataframe + are dictated by the ``allow_missing`` and ``allow_extra`` flags. + allow_missing: If ``False``, each target of the objective must have + (exactly) one corresponding column in the given dataframe. If ``True``, + the dataframe may contain only a subset of target columns. + allow_extra: If ``False``, every column present in the dataframe must + correspond to (exactly) one target of the objective. If ``True``, the + dataframe may contain additional non-target-related columns, which + will be ignored. Returns: - A new dataframe with the targets in computational representation. + A corresponding dataframe with the targets in computational representation. """ diff --git a/baybe/objectives/desirability.py b/baybe/objectives/desirability.py index 3654837db..4ae08a300 100644 --- a/baybe/objectives/desirability.py +++ b/baybe/objectives/desirability.py @@ -140,11 +140,18 @@ def __str__(self) -> str: return to_string("Objective", *fields) @override - def transform(self, data: pd.DataFrame) -> pd.DataFrame: + def transform( + self, + df: pd.DataFrame, + /, + *, + allow_missing: bool = False, + allow_extra: bool = False, + ) -> pd.DataFrame: # Transform all targets individually - transformed = data[[t.name for t in self.targets]].copy() + transformed = df[[t.name for t in self.targets]].copy() for target in self.targets: - transformed[target.name] = target.transform(data[[target.name]]) + transformed[target.name] = target.transform(df[[target.name]]) # Scalarize the transformed targets into desirability values vals = scalarize(transformed.values, self.scalarizer, self._normalized_weights) diff --git a/baybe/objectives/single.py b/baybe/objectives/single.py index 0dd6c3028..9fb053fb4 100644 --- a/baybe/objectives/single.py +++ b/baybe/objectives/single.py @@ -38,8 +38,15 @@ def targets(self) -> tuple[Target, ...]: return (self._target,) @override - def transform(self, data: pd.DataFrame) -> pd.DataFrame: - target_data = data[[self._target.name]].copy() + def transform( + self, + df: pd.DataFrame, + /, + *, + allow_missing: bool = False, + allow_extra: bool = False, + ) -> pd.DataFrame: + target_data = df[[self._target.name]].copy() return self._target.transform(target_data) From 820ab8e513fa80bd87e610268834412fd1cfad01 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 11 Oct 2024 13:38:25 +0200 Subject: [PATCH 02/13] Add validation to objective transformation --- baybe/objectives/desirability.py | 9 ++- baybe/objectives/single.py | 9 ++- baybe/searchspace/continuous.py | 4 +- baybe/searchspace/discrete.py | 4 +- baybe/searchspace/validation.py | 50 +-------------- baybe/surrogates/base.py | 4 +- baybe/utils/validation.py | 61 ++++++++++++++++++- examples/Multi_Target/desirability.py | 2 +- .../validation/test_searchspace_validation.py | 8 +-- 9 files changed, 86 insertions(+), 65 deletions(-) diff --git a/baybe/objectives/desirability.py b/baybe/objectives/desirability.py index 4ae08a300..ae31e604e 100644 --- a/baybe/objectives/desirability.py +++ b/baybe/objectives/desirability.py @@ -21,7 +21,7 @@ from baybe.utils.dataframe import pretty_print_df from baybe.utils.numerical import geom_mean from baybe.utils.plotting import to_string -from baybe.utils.validation import finite_float +from baybe.utils.validation import finite_float, get_transform_objects def _is_all_numerical_targets( @@ -148,8 +148,13 @@ def transform( allow_missing: bool = False, allow_extra: bool = False, ) -> pd.DataFrame: + # Extract the relevant part of the dataframe + targets = get_transform_objects( + self.targets, df, allow_missing=allow_missing, allow_extra=allow_extra + ) + transformed = df[[t.name for t in targets]].copy() + # Transform all targets individually - transformed = df[[t.name for t in self.targets]].copy() for target in self.targets: transformed[target.name] = target.transform(df[[target.name]]) diff --git a/baybe/objectives/single.py b/baybe/objectives/single.py index 9fb053fb4..0537eddd2 100644 --- a/baybe/objectives/single.py +++ b/baybe/objectives/single.py @@ -11,6 +11,7 @@ from baybe.targets.base import Target from baybe.utils.dataframe import pretty_print_df from baybe.utils.plotting import to_string +from baybe.utils.validation import get_transform_objects @define(frozen=True, slots=False) @@ -46,7 +47,13 @@ def transform( allow_missing: bool = False, allow_extra: bool = False, ) -> pd.DataFrame: - target_data = df[[self._target.name]].copy() + # Even for a single target, it is convenient to use the existing machinery + # instead of re-implementing the validation logic + targets = get_transform_objects( + [self._target], df, allow_missing=allow_missing, allow_extra=allow_extra + ) + target_data = df[[t.name for t in targets]].copy() + return self._target.transform(target_data) diff --git a/baybe/searchspace/continuous.py b/baybe/searchspace/continuous.py index c0e67bc96..5d0c20799 100644 --- a/baybe/searchspace/continuous.py +++ b/baybe/searchspace/continuous.py @@ -25,13 +25,13 @@ from baybe.parameters.base import ContinuousParameter from baybe.parameters.utils import get_parameters_from_dataframe, sort_parameters from baybe.searchspace.validation import ( - get_transform_parameters, validate_parameter_names, ) from baybe.serialization import SerialMixin, converter, select_constructor_hook from baybe.utils.basic import to_tuple from baybe.utils.dataframe import pretty_print_df from baybe.utils.plotting import to_string +from baybe.utils.validation import get_transform_objects if TYPE_CHECKING: from baybe.searchspace.core import SearchSpace @@ -343,7 +343,7 @@ def transform( # <<<<<<<<<< Deprecation # Extract the parameters to be transformed - parameters = get_transform_parameters( + parameters = get_transform_objects( self.parameters, df, allow_missing, allow_extra ) diff --git a/baybe/searchspace/discrete.py b/baybe/searchspace/discrete.py index d0c3f1fb3..c93c7cde4 100644 --- a/baybe/searchspace/discrete.py +++ b/baybe/searchspace/discrete.py @@ -27,7 +27,6 @@ from baybe.parameters.base import DiscreteParameter, Parameter from baybe.parameters.utils import get_parameters_from_dataframe, sort_parameters from baybe.searchspace.validation import ( - get_transform_parameters, validate_parameter_names, validate_parameters, ) @@ -42,6 +41,7 @@ from baybe.utils.memory import bytes_to_human_readable from baybe.utils.numerical import DTypeFloatNumpy from baybe.utils.plotting import to_string +from baybe.utils.validation import get_transform_objects if TYPE_CHECKING: import polars as pl @@ -753,7 +753,7 @@ def transform( # <<<<<<<<<< Deprecation # Extract the parameters to be transformed - parameters = get_transform_parameters( + parameters = get_transform_objects( self.parameters, df, allow_missing, allow_extra ) diff --git a/baybe/searchspace/validation.py b/baybe/searchspace/validation.py index 5e8e39e71..008d98ce8 100644 --- a/baybe/searchspace/validation.py +++ b/baybe/searchspace/validation.py @@ -1,16 +1,11 @@ """Validation functionality for search spaces.""" -from collections.abc import Collection, Sequence -from typing import TypeVar - -import pandas as pd +from collections.abc import Collection from baybe.exceptions import EmptySearchSpaceError from baybe.parameters import TaskParameter from baybe.parameters.base import Parameter -_T = TypeVar("_T", bound=Parameter) - def validate_parameter_names( # noqa: DOC101, DOC103 parameters: Collection[Parameter], @@ -44,46 +39,3 @@ def validate_parameters(parameters: Collection[Parameter]) -> None: # noqa: DOC # Assert: unique names validate_parameter_names(parameters) - - -def get_transform_parameters( - parameters: Sequence[_T], - df: pd.DataFrame, - allow_missing: bool, - allow_extra: bool, -) -> list[_T]: - """Extract the parameters relevant for transforming a given dataframe. - - Args: - parameters: The parameters to be considered for transformation (provided - they have match in the given dataframe). - df: See :meth:`baybe.searchspace.core.SearchSpace.transform`. - allow_missing: See :meth:`baybe.searchspace.core.SearchSpace.transform`. - allow_extra: See :meth:`baybe.searchspace.core.SearchSpace.transform`. - - Raises: - ValueError: If the given parameters and dataframe are not compatible - under the specified values for the Boolean flags. - - Returns: - The (subset of) parameters that need to be considered for the transformation. - """ - parameter_names = [p.name for p in parameters] - - if (not allow_missing) and (missing := set(parameter_names) - set(df)): # type: ignore[arg-type] - raise ValueError( - f"The search space parameter(s) {missing} cannot be matched against " - f"the provided dataframe. If you want to transform a subset of " - f"parameter columns, explicitly set `allow_missing=True`." - ) - - if (not allow_extra) and (extra := set(df) - set(parameter_names)): - raise ValueError( - f"The provided dataframe column(s) {extra} cannot be matched against" - f"the search space parameters. If you want to transform a dataframe " - f"with additional columns, explicitly set `allow_extra=True'." - ) - - return ( - [p for p in parameters if p.name in df] if allow_missing else list(parameters) - ) diff --git a/baybe/surrogates/base.py b/baybe/surrogates/base.py index f7c8c7063..6a6deee82 100644 --- a/baybe/surrogates/base.py +++ b/baybe/surrogates/base.py @@ -187,7 +187,7 @@ def _make_output_scaler( scaler = factory(1) # TODO: Consider taking into account target boundaries when available - scaler(to_tensor(objective.transform(measurements))) + scaler(to_tensor(objective.transform(measurements, allow_extra=True))) scaler.eval() return scaler @@ -336,7 +336,7 @@ def fit( # Transform and fit train_x_comp_rep, train_y_comp_rep = to_tensor( searchspace.transform(measurements, allow_extra=True), - objective.transform(measurements), + objective.transform(measurements, allow_extra=True), ) train_x = self._input_scaler.transform(train_x_comp_rep) train_y = ( diff --git a/baybe/utils/validation.py b/baybe/utils/validation.py index 6a014c94b..e2896d9df 100644 --- a/baybe/utils/validation.py +++ b/baybe/utils/validation.py @@ -1,11 +1,20 @@ """Validation utilities.""" +from __future__ import annotations + import math -from collections.abc import Callable -from typing import Any +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, TypeVar +import pandas as pd from attrs import Attribute +if TYPE_CHECKING: + from baybe.parameters.base import Parameter + from baybe.targets.base import Target + + _T = TypeVar("_T", bound=Parameter | Target) + def validate_not_nan(self: Any, attribute: Attribute, value: Any) -> None: """Attrs-compatible validator to forbid 'nan' values.""" @@ -66,3 +75,51 @@ def validator(self: Any, attribute: Attribute, value: Any) -> None: non_inf_float = _make_restricted_float_validator(allow_nan=True, allow_inf=False) """Validator for non-infinite floats.""" + + +def get_transform_objects( + objects: Sequence[_T], + df: pd.DataFrame, + allow_missing: bool, + allow_extra: bool, +) -> list[_T]: + """Extract the objects relevant for transforming a given dataframe. + + The passed object are assumed to have corresponding columns in the given dataframe, + identified through their name attribute. The function returns the subset of objects + that have a corresponding column in the dataframe and thus provide the necessary + information for transforming the dataframe. + + Args: + objects: A collection of objects to be considered for transformation (provided + they have a match in the given dataframe). + df: The dataframe to be searched for corresponding columns. + allow_missing: Flag controlling if objects are allowed to have no corresponding + columns in the dataframe. + allow_extra: Flag controlling if the dataframe is allowed to have columns + that have corresponding objects. + + Raises: + ValueError: If the given objects and dataframe are not compatible + under the specified values for the Boolean flags. + + Returns: + The (subset of) objects that need to be considered for the transformation. + """ + names = [p.name for p in objects] + + if (not allow_missing) and (missing := set(names) - set(df)): # type: ignore[arg-type] + raise ValueError( + f"The object(s) named {missing} cannot be matched against " + f"the provided dataframe. If you want to transform a subset of " + f"columns, explicitly set `allow_missing=True`." + ) + + if (not allow_extra) and (extra := set(df) - set(names)): + raise ValueError( + f"The provided dataframe column(s) {extra} cannot be matched against" + f"the given objects. If you want to transform a dataframe " + f"with additional columns, explicitly set `allow_extra=True'." + ) + + return [p for p in objects if p.name in df] if allow_missing else list(objects) diff --git a/examples/Multi_Target/desirability.py b/examples/Multi_Target/desirability.py index 28eca2a40..cb3164635 100644 --- a/examples/Multi_Target/desirability.py +++ b/examples/Multi_Target/desirability.py @@ -109,7 +109,7 @@ rec = campaign.recommend(batch_size=3) add_fake_measurements(rec, campaign.targets) campaign.add_measurements(rec) - desirability = campaign.objective.transform(campaign.measurements) + desirability = campaign.objective.transform(campaign.measurements, allow_extra=True) print(f"\n\n#### ITERATION {kIter+1} ####") print("\nRecommended measurements with fake measured results:\n") diff --git a/tests/validation/test_searchspace_validation.py b/tests/validation/test_searchspace_validation.py index d2209940c..cbb6ea7d2 100644 --- a/tests/validation/test_searchspace_validation.py +++ b/tests/validation/test_searchspace_validation.py @@ -5,7 +5,7 @@ from pytest import param from baybe.parameters.numerical import NumericalDiscreteParameter -from baybe.searchspace.validation import get_transform_parameters +from baybe.utils.validation import get_transform_objects parameters = [NumericalDiscreteParameter("d1", [0, 1])] @@ -15,7 +15,7 @@ [ param( pd.DataFrame(columns=[]), - r"parameter\(s\) \{'d1'\} cannot be matched", + r"object\(s\) named \{'d1'\} cannot be matched", id="missing", ), param( @@ -28,7 +28,7 @@ def test_invalid_transforms(df, match): """Transforming dataframes with incorrect columns raises an error.""" with pytest.raises(ValueError, match=match): - get_transform_parameters(parameters, df, allow_missing=False, allow_extra=False) + get_transform_objects(parameters, df, allow_missing=False, allow_extra=False) @pytest.mark.parametrize( @@ -41,4 +41,4 @@ def test_invalid_transforms(df, match): ) def test_valid_transforms(df, missing, extra): """When providing the appropriate flags, the columns of the dataframe to be transformed can be flexibly chosen.""" # noqa - get_transform_parameters(parameters, df, allow_missing=missing, allow_extra=extra) + get_transform_objects(parameters, df, allow_missing=missing, allow_extra=extra) From cef24bf77c36e85de7db3d58ce77414fa1f997cd Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Fri, 11 Oct 2024 13:44:27 +0200 Subject: [PATCH 03/13] Update CHANGELOG.md --- CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c41223b7..251c7a258 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,16 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] +### Breaking Changes +- Passing a dataframe via the `data` argument to `Objective.transform` is no longer + possible. The dataframe must now be passed as positional argument. +- Providing additional dataframe columns to `Objective.transforms` now requires + explicitly passing `allow_extra=True` + +### Added +- `allow_missing` and `allow_extra` keyword arguments to `Objective.transform` + ## [0.11.2] - 2024-10-11 ### Added - `n_restarts` and `n_raw_samples` keywords to configure continuous optimization From ca9e57286102675f3b05088664f1bcd6ed328202 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 14 Oct 2024 10:30:58 +0200 Subject: [PATCH 04/13] Remove unnecessary if clause --- baybe/utils/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/baybe/utils/validation.py b/baybe/utils/validation.py index e2896d9df..7c113b28a 100644 --- a/baybe/utils/validation.py +++ b/baybe/utils/validation.py @@ -122,4 +122,4 @@ def get_transform_objects( f"with additional columns, explicitly set `allow_extra=True'." ) - return [p for p in objects if p.name in df] if allow_missing else list(objects) + return [p for p in objects if p.name in df] From 3626fc8fa7f1cbc419fa9373c80d45224adecbae Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 14 Oct 2024 10:13:31 +0200 Subject: [PATCH 05/13] Fix text Co-authored-by: Martin Fitzner <17951239+Scienfitz@users.noreply.github.com> --- CHANGELOG.md | 2 +- baybe/utils/validation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 251c7a258..df9198808 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Breaking Changes - Passing a dataframe via the `data` argument to `Objective.transform` is no longer possible. The dataframe must now be passed as positional argument. -- Providing additional dataframe columns to `Objective.transforms` now requires +- Providing additional dataframe columns to `Objective.transform` now requires explicitly passing `allow_extra=True` ### Added diff --git a/baybe/utils/validation.py b/baybe/utils/validation.py index 7c113b28a..a89298c2c 100644 --- a/baybe/utils/validation.py +++ b/baybe/utils/validation.py @@ -97,7 +97,7 @@ def get_transform_objects( allow_missing: Flag controlling if objects are allowed to have no corresponding columns in the dataframe. allow_extra: Flag controlling if the dataframe is allowed to have columns - that have corresponding objects. + that have no corresponding objects. Raises: ValueError: If the given objects and dataframe are not compatible From ed0be28deb4ec9858963079b9ca85c801bcbad38 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 14 Oct 2024 10:58:29 +0200 Subject: [PATCH 06/13] Implement deprecation mechanism --- CHANGELOG.md | 12 +++++----- baybe/objectives/desirability.py | 36 ++++++++++++++++++++++++++-- baybe/objectives/single.py | 36 ++++++++++++++++++++++++++-- tests/test_deprecations.py | 40 ++++++++++++++++++++++++++++++-- 4 files changed, 112 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df9198808..79f8c45f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,15 +5,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] -### Breaking Changes -- Passing a dataframe via the `data` argument to `Objective.transform` is no longer - possible. The dataframe must now be passed as positional argument. -- Providing additional dataframe columns to `Objective.transform` now requires - explicitly passing `allow_extra=True` - ### Added - `allow_missing` and `allow_extra` keyword arguments to `Objective.transform` +### Deprecations +- Passing a dataframe via the `data` argument to `Objective.transform` is no longer + possible. The dataframe must now be passed as positional argument. +- The new `allow_extra` flag is automatically set to `True` in `Objective.transform` + when left unspecified + ## [0.11.2] - 2024-10-11 ### Added - `n_restarts` and `n_raw_samples` keywords to configure continuous optimization diff --git a/baybe/objectives/desirability.py b/baybe/objectives/desirability.py index ae31e604e..480f239f4 100644 --- a/baybe/objectives/desirability.py +++ b/baybe/objectives/desirability.py @@ -1,6 +1,7 @@ """Functionality for desirability objectives.""" import gc +import warnings from collections.abc import Callable from functools import cached_property, partial from typing import TypeGuard @@ -142,12 +143,43 @@ def __str__(self) -> str: @override def transform( self, - df: pd.DataFrame, + df: pd.DataFrame | None = None, /, *, allow_missing: bool = False, - allow_extra: bool = False, + allow_extra: bool | None = None, + data: pd.DataFrame | None = None, ) -> pd.DataFrame: + # >>>>>>>>>> Deprecation + if not ((df is None) ^ (data is None)): + raise ValueError( + "Provide the dataframe to be transformed as argument to `df`." + ) + + if data is not None: + df = data + warnings.warn( + "Providing the dataframe via the `data` argument is deprecated and " + "will be removed in a future version. Please pass your dataframe " + "as positional argument instead.", + DeprecationWarning, + ) + + # Mypy does not infer from the above that `df` must be a dataframe here + assert isinstance(df, pd.DataFrame) + + if allow_extra is None: + allow_extra = True + if set(df.columns) - {p.name for p in self.targets}: + warnings.warn( + "For backward compatibility, the new `allow_extra` flag is set " + "to `True` when left unspecified. However, this behavior will be " + "changed in a future version. If you want to invoke the old " + "behavior, please explicitly set `allow_extra=True`.", + DeprecationWarning, + ) + # <<<<<<<<<< Deprecation + # Extract the relevant part of the dataframe targets = get_transform_objects( self.targets, df, allow_missing=allow_missing, allow_extra=allow_extra diff --git a/baybe/objectives/single.py b/baybe/objectives/single.py index 0537eddd2..96810df50 100644 --- a/baybe/objectives/single.py +++ b/baybe/objectives/single.py @@ -1,6 +1,7 @@ """Functionality for single-target objectives.""" import gc +import warnings import pandas as pd from attrs import define, field @@ -41,12 +42,43 @@ def targets(self) -> tuple[Target, ...]: @override def transform( self, - df: pd.DataFrame, + df: pd.DataFrame | None = None, /, *, allow_missing: bool = False, - allow_extra: bool = False, + allow_extra: bool | None = None, + data: pd.DataFrame | None = None, ) -> pd.DataFrame: + # >>>>>>>>>> Deprecation + if not ((df is None) ^ (data is None)): + raise ValueError( + "Provide the dataframe to be transformed as argument to `df`." + ) + + if data is not None: + df = data + warnings.warn( + "Providing the dataframe via the `data` argument is deprecated and " + "will be removed in a future version. Please pass your dataframe " + "as positional argument instead.", + DeprecationWarning, + ) + + # Mypy does not infer from the above that `df` must be a dataframe here + assert isinstance(df, pd.DataFrame) + + if allow_extra is None: + allow_extra = True + if set(df.columns) - {p.name for p in self.targets}: + warnings.warn( + "For backward compatibility, the new `allow_extra` flag is set " + "to `True` when left unspecified. However, this behavior will be " + "changed in a future version. If you want to invoke the old " + "behavior, please explicitly set `allow_extra=True`.", + DeprecationWarning, + ) + # <<<<<<<<<< Deprecation + # Even for a single target, it is convenient to use the existing machinery # instead of re-implementing the validation logic targets = get_transform_objects( diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index 0ad661e48..3cbed3809 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -16,6 +16,7 @@ from baybe.objective import Objective as OldObjective from baybe.objectives.base import Objective from baybe.objectives.desirability import DesirabilityObjective +from baybe.objectives.single import SingleTargetObjective from baybe.parameters.numerical import NumericalContinuousParameter from baybe.recommenders.pure.bayesian import ( BotorchRecommender, @@ -106,12 +107,16 @@ def test_samples_full_factorial(): SubspaceContinuous(parameters).samples_full_factorial(n_points=1) -def test_transform_interface(searchspace): +def test_subspace_transform_interface(searchspace): """Using the deprecated transform interface raises a warning.""" # Not providing `allow_extra` when there are additional columns with pytest.warns(DeprecationWarning): searchspace.discrete.transform( - pd.DataFrame(columns=["additional", *searchspace.discrete.exp_rep.columns]) + pd.DataFrame(columns=["additional", *searchspace.discrete.exp_rep.columns]), + ) + with pytest.warns(DeprecationWarning): + searchspace.continuous.transform( + pd.DataFrame(columns=["additional", *searchspace.discrete.exp_rep.columns]), ) # Passing dataframe via `data` @@ -119,6 +124,10 @@ def test_transform_interface(searchspace): searchspace.discrete.transform( data=searchspace.discrete.exp_rep, allow_extra=True ) + with pytest.warns(DeprecationWarning): + searchspace.continuous.transform( + data=searchspace.discrete.exp_rep, allow_extra=True + ) def test_surrogate_registration(): @@ -179,3 +188,30 @@ def test_constraint_config_deserialization(type_, op): warnings.filterwarnings("ignore", category=DeprecationWarning) actual = Constraint.from_json(config) assert expected == actual, (expected, actual) + + +def test_objective_transform_interface(): + """Using the deprecated transform interface raises a warning.""" + single = SingleTargetObjective(NumericalTarget("A", "MAX")) + desirability = DesirabilityObjective( + [ + NumericalTarget("A", "MAX", (0, 1)), + NumericalTarget("B", "MIN", (-1, 1)), + ] + ) + + # Not providing `allow_extra` when there are additional columns + with pytest.warns(DeprecationWarning): + single.transform( + pd.DataFrame(columns=["A", "additional"]), + ) + with pytest.warns(DeprecationWarning): + desirability.transform( + pd.DataFrame(columns=["A", "B", "additional"]), + ) + + # Passing dataframe via `data` + with pytest.warns(DeprecationWarning): + single.transform(data=pd.DataFrame(columns=["A"]), allow_extra=True) + with pytest.warns(DeprecationWarning): + desirability.transform(data=pd.DataFrame(columns=["A", "B"]), allow_extra=True) From 93d9ff557217f1ba8ee07aa3b71e6bd9b5f29a82 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 14 Oct 2024 14:13:47 +0200 Subject: [PATCH 07/13] Move utility to utils/dataframe.py --- baybe/objectives/desirability.py | 4 +- baybe/objectives/single.py | 3 +- baybe/searchspace/continuous.py | 3 +- baybe/searchspace/discrete.py | 2 +- baybe/utils/dataframe.py | 59 +++++++++++++++++-- baybe/utils/validation.py | 59 +------------------ .../validation/test_searchspace_validation.py | 2 +- 7 files changed, 61 insertions(+), 71 deletions(-) diff --git a/baybe/objectives/desirability.py b/baybe/objectives/desirability.py index 480f239f4..c8f067aff 100644 --- a/baybe/objectives/desirability.py +++ b/baybe/objectives/desirability.py @@ -19,10 +19,10 @@ from baybe.targets.base import Target from baybe.targets.numerical import NumericalTarget from baybe.utils.basic import to_tuple -from baybe.utils.dataframe import pretty_print_df +from baybe.utils.dataframe import get_transform_objects, pretty_print_df from baybe.utils.numerical import geom_mean from baybe.utils.plotting import to_string -from baybe.utils.validation import finite_float, get_transform_objects +from baybe.utils.validation import finite_float def _is_all_numerical_targets( diff --git a/baybe/objectives/single.py b/baybe/objectives/single.py index 96810df50..53bdfc0fe 100644 --- a/baybe/objectives/single.py +++ b/baybe/objectives/single.py @@ -10,9 +10,8 @@ from baybe.objectives.base import Objective from baybe.targets.base import Target -from baybe.utils.dataframe import pretty_print_df +from baybe.utils.dataframe import get_transform_objects, pretty_print_df from baybe.utils.plotting import to_string -from baybe.utils.validation import get_transform_objects @define(frozen=True, slots=False) diff --git a/baybe/searchspace/continuous.py b/baybe/searchspace/continuous.py index 5d0c20799..473d61e17 100644 --- a/baybe/searchspace/continuous.py +++ b/baybe/searchspace/continuous.py @@ -29,9 +29,8 @@ ) from baybe.serialization import SerialMixin, converter, select_constructor_hook from baybe.utils.basic import to_tuple -from baybe.utils.dataframe import pretty_print_df +from baybe.utils.dataframe import get_transform_objects, pretty_print_df from baybe.utils.plotting import to_string -from baybe.utils.validation import get_transform_objects if TYPE_CHECKING: from baybe.searchspace.core import SearchSpace diff --git a/baybe/searchspace/discrete.py b/baybe/searchspace/discrete.py index c93c7cde4..e2990a6fb 100644 --- a/baybe/searchspace/discrete.py +++ b/baybe/searchspace/discrete.py @@ -36,12 +36,12 @@ from baybe.utils.dataframe import ( df_drop_single_value_columns, fuzzy_row_match, + get_transform_objects, pretty_print_df, ) from baybe.utils.memory import bytes_to_human_readable from baybe.utils.numerical import DTypeFloatNumpy from baybe.utils.plotting import to_string -from baybe.utils.validation import get_transform_objects if TYPE_CHECKING: import polars as pl diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index e365c5bc3..070fe6008 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -4,11 +4,7 @@ import logging from collections.abc import Collection, Iterable, Sequence -from typing import ( - TYPE_CHECKING, - Literal, - overload, -) +from typing import TYPE_CHECKING, Literal, TypeVar, overload import numpy as np import pandas as pd @@ -21,7 +17,10 @@ if TYPE_CHECKING: from torch import Tensor - from baybe.parameters import Parameter + from baybe.parameters.base import Parameter + from baybe.targets.base import Target + + _T = TypeVar("_T", bound=Parameter | Target) # Logging _logger = logging.getLogger(__name__) @@ -503,3 +502,51 @@ def pretty_print_df( ) str_df = str(str_df) return str_df + + +def get_transform_objects( + objects: Sequence[_T], + df: pd.DataFrame, + allow_missing: bool, + allow_extra: bool, +) -> list[_T]: + """Extract the objects relevant for transforming a given dataframe. + + The passed object are assumed to have corresponding columns in the given dataframe, + identified through their name attribute. The function returns the subset of objects + that have a corresponding column in the dataframe and thus provide the necessary + information for transforming the dataframe. + + Args: + objects: A collection of objects to be considered for transformation (provided + they have a match in the given dataframe). + df: The dataframe to be searched for corresponding columns. + allow_missing: Flag controlling if objects are allowed to have no corresponding + columns in the dataframe. + allow_extra: Flag controlling if the dataframe is allowed to have columns + that have no corresponding objects. + + Raises: + ValueError: If the given objects and dataframe are not compatible + under the specified values for the Boolean flags. + + Returns: + The (subset of) objects that need to be considered for the transformation. + """ + names = [p.name for p in objects] + + if (not allow_missing) and (missing := set(names) - set(df)): # type: ignore[arg-type] + raise ValueError( + f"The object(s) named {missing} cannot be matched against " + f"the provided dataframe. If you want to transform a subset of " + f"columns, explicitly set `allow_missing=True`." + ) + + if (not allow_extra) and (extra := set(df) - set(names)): + raise ValueError( + f"The provided dataframe column(s) {extra} cannot be matched against" + f"the given objects. If you want to transform a dataframe " + f"with additional columns, explicitly set `allow_extra=True'." + ) + + return [p for p in objects if p.name in df] diff --git a/baybe/utils/validation.py b/baybe/utils/validation.py index a89298c2c..a16d018c4 100644 --- a/baybe/utils/validation.py +++ b/baybe/utils/validation.py @@ -3,18 +3,11 @@ from __future__ import annotations import math -from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, TypeVar +from collections.abc import Callable +from typing import Any -import pandas as pd from attrs import Attribute -if TYPE_CHECKING: - from baybe.parameters.base import Parameter - from baybe.targets.base import Target - - _T = TypeVar("_T", bound=Parameter | Target) - def validate_not_nan(self: Any, attribute: Attribute, value: Any) -> None: """Attrs-compatible validator to forbid 'nan' values.""" @@ -75,51 +68,3 @@ def validator(self: Any, attribute: Attribute, value: Any) -> None: non_inf_float = _make_restricted_float_validator(allow_nan=True, allow_inf=False) """Validator for non-infinite floats.""" - - -def get_transform_objects( - objects: Sequence[_T], - df: pd.DataFrame, - allow_missing: bool, - allow_extra: bool, -) -> list[_T]: - """Extract the objects relevant for transforming a given dataframe. - - The passed object are assumed to have corresponding columns in the given dataframe, - identified through their name attribute. The function returns the subset of objects - that have a corresponding column in the dataframe and thus provide the necessary - information for transforming the dataframe. - - Args: - objects: A collection of objects to be considered for transformation (provided - they have a match in the given dataframe). - df: The dataframe to be searched for corresponding columns. - allow_missing: Flag controlling if objects are allowed to have no corresponding - columns in the dataframe. - allow_extra: Flag controlling if the dataframe is allowed to have columns - that have no corresponding objects. - - Raises: - ValueError: If the given objects and dataframe are not compatible - under the specified values for the Boolean flags. - - Returns: - The (subset of) objects that need to be considered for the transformation. - """ - names = [p.name for p in objects] - - if (not allow_missing) and (missing := set(names) - set(df)): # type: ignore[arg-type] - raise ValueError( - f"The object(s) named {missing} cannot be matched against " - f"the provided dataframe. If you want to transform a subset of " - f"columns, explicitly set `allow_missing=True`." - ) - - if (not allow_extra) and (extra := set(df) - set(names)): - raise ValueError( - f"The provided dataframe column(s) {extra} cannot be matched against" - f"the given objects. If you want to transform a dataframe " - f"with additional columns, explicitly set `allow_extra=True'." - ) - - return [p for p in objects if p.name in df] diff --git a/tests/validation/test_searchspace_validation.py b/tests/validation/test_searchspace_validation.py index cbb6ea7d2..ed8d154ce 100644 --- a/tests/validation/test_searchspace_validation.py +++ b/tests/validation/test_searchspace_validation.py @@ -5,7 +5,7 @@ from pytest import param from baybe.parameters.numerical import NumericalDiscreteParameter -from baybe.utils.validation import get_transform_objects +from baybe.utils.dataframe import get_transform_objects parameters = [NumericalDiscreteParameter("d1", [0, 1])] From b31a78218c3078ba97d51d70e2f14f49e28953bf Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 14 Oct 2024 14:21:52 +0200 Subject: [PATCH 08/13] Switch argument order --- baybe/objectives/desirability.py | 2 +- baybe/objectives/single.py | 2 +- baybe/searchspace/continuous.py | 2 +- baybe/searchspace/discrete.py | 2 +- baybe/utils/dataframe.py | 7 ++----- tests/validation/test_searchspace_validation.py | 4 ++-- 6 files changed, 8 insertions(+), 11 deletions(-) diff --git a/baybe/objectives/desirability.py b/baybe/objectives/desirability.py index c8f067aff..ebdad88d8 100644 --- a/baybe/objectives/desirability.py +++ b/baybe/objectives/desirability.py @@ -182,7 +182,7 @@ def transform( # Extract the relevant part of the dataframe targets = get_transform_objects( - self.targets, df, allow_missing=allow_missing, allow_extra=allow_extra + df, self.targets, allow_missing=allow_missing, allow_extra=allow_extra ) transformed = df[[t.name for t in targets]].copy() diff --git a/baybe/objectives/single.py b/baybe/objectives/single.py index 53bdfc0fe..705062dba 100644 --- a/baybe/objectives/single.py +++ b/baybe/objectives/single.py @@ -81,7 +81,7 @@ def transform( # Even for a single target, it is convenient to use the existing machinery # instead of re-implementing the validation logic targets = get_transform_objects( - [self._target], df, allow_missing=allow_missing, allow_extra=allow_extra + df, [self._target], allow_missing=allow_missing, allow_extra=allow_extra ) target_data = df[[t.name for t in targets]].copy() diff --git a/baybe/searchspace/continuous.py b/baybe/searchspace/continuous.py index 473d61e17..96a876578 100644 --- a/baybe/searchspace/continuous.py +++ b/baybe/searchspace/continuous.py @@ -343,7 +343,7 @@ def transform( # Extract the parameters to be transformed parameters = get_transform_objects( - self.parameters, df, allow_missing, allow_extra + df, self.parameters, allow_missing, allow_extra ) # Transform the parameters diff --git a/baybe/searchspace/discrete.py b/baybe/searchspace/discrete.py index e2990a6fb..d5d0fa44e 100644 --- a/baybe/searchspace/discrete.py +++ b/baybe/searchspace/discrete.py @@ -754,7 +754,7 @@ def transform( # Extract the parameters to be transformed parameters = get_transform_objects( - self.parameters, df, allow_missing, allow_extra + df, self.parameters, allow_missing, allow_extra ) # If the transformed values are not required, return an empty dataframe diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 070fe6008..ac6897800 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -505,10 +505,7 @@ def pretty_print_df( def get_transform_objects( - objects: Sequence[_T], - df: pd.DataFrame, - allow_missing: bool, - allow_extra: bool, + df: pd.DataFrame, objects: Sequence[_T], allow_missing: bool, allow_extra: bool ) -> list[_T]: """Extract the objects relevant for transforming a given dataframe. @@ -518,9 +515,9 @@ def get_transform_objects( information for transforming the dataframe. Args: + df: The dataframe to be searched for corresponding columns. objects: A collection of objects to be considered for transformation (provided they have a match in the given dataframe). - df: The dataframe to be searched for corresponding columns. allow_missing: Flag controlling if objects are allowed to have no corresponding columns in the dataframe. allow_extra: Flag controlling if the dataframe is allowed to have columns diff --git a/tests/validation/test_searchspace_validation.py b/tests/validation/test_searchspace_validation.py index ed8d154ce..c18e56c24 100644 --- a/tests/validation/test_searchspace_validation.py +++ b/tests/validation/test_searchspace_validation.py @@ -28,7 +28,7 @@ def test_invalid_transforms(df, match): """Transforming dataframes with incorrect columns raises an error.""" with pytest.raises(ValueError, match=match): - get_transform_objects(parameters, df, allow_missing=False, allow_extra=False) + get_transform_objects(df, parameters, allow_missing=False, allow_extra=False) @pytest.mark.parametrize( @@ -41,4 +41,4 @@ def test_invalid_transforms(df, match): ) def test_valid_transforms(df, missing, extra): """When providing the appropriate flags, the columns of the dataframe to be transformed can be flexibly chosen.""" # noqa - get_transform_objects(parameters, df, allow_missing=missing, allow_extra=extra) + get_transform_objects(df, parameters, allow_missing=missing, allow_extra=extra) From 9c65ad011557d78c8943bb15d81a4bd11e3563f4 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 14 Oct 2024 14:22:55 +0200 Subject: [PATCH 09/13] Make arguments positional-only and keyword-only --- baybe/searchspace/continuous.py | 2 +- baybe/searchspace/discrete.py | 2 +- baybe/utils/dataframe.py | 7 ++++++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/baybe/searchspace/continuous.py b/baybe/searchspace/continuous.py index 96a876578..51fc4726a 100644 --- a/baybe/searchspace/continuous.py +++ b/baybe/searchspace/continuous.py @@ -343,7 +343,7 @@ def transform( # Extract the parameters to be transformed parameters = get_transform_objects( - df, self.parameters, allow_missing, allow_extra + df, self.parameters, allow_missing=allow_missing, allow_extra=allow_extra ) # Transform the parameters diff --git a/baybe/searchspace/discrete.py b/baybe/searchspace/discrete.py index d5d0fa44e..88f26fef9 100644 --- a/baybe/searchspace/discrete.py +++ b/baybe/searchspace/discrete.py @@ -754,7 +754,7 @@ def transform( # Extract the parameters to be transformed parameters = get_transform_objects( - df, self.parameters, allow_missing, allow_extra + df, self.parameters, allow_missing=allow_missing, allow_extra=allow_extra ) # If the transformed values are not required, return an empty dataframe diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index ac6897800..5fec30b04 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -505,7 +505,12 @@ def pretty_print_df( def get_transform_objects( - df: pd.DataFrame, objects: Sequence[_T], allow_missing: bool, allow_extra: bool + df: pd.DataFrame, + objects: Sequence[_T], + /, + *, + allow_missing: bool, + allow_extra: bool, ) -> list[_T]: """Extract the objects relevant for transforming a given dataframe. From b50603cb438aba1a1c931630cfb8c16577834553 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 14 Oct 2024 14:40:38 +0200 Subject: [PATCH 10/13] Add defaults for Boolean flags --- baybe/utils/dataframe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 5fec30b04..15f2b638e 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -509,8 +509,8 @@ def get_transform_objects( objects: Sequence[_T], /, *, - allow_missing: bool, - allow_extra: bool, + allow_missing: bool = False, + allow_extra: bool = False, ) -> list[_T]: """Extract the objects relevant for transforming a given dataframe. From 875e84d5707c02bfcd26a69c109bab387bddd468 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 14 Oct 2024 14:42:39 +0200 Subject: [PATCH 11/13] Add deprecation for get_transform_parameters --- CHANGELOG.md | 1 + baybe/searchspace/validation.py | 26 +++++++++++++++++++++++++- tests/test_deprecations.py | 9 +++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 79f8c45f5..a4aab98b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 possible. The dataframe must now be passed as positional argument. - The new `allow_extra` flag is automatically set to `True` in `Objective.transform` when left unspecified +- `get_transform_parameters` has been replaced with `get_transform_objects` ## [0.11.2] - 2024-10-11 ### Added diff --git a/baybe/searchspace/validation.py b/baybe/searchspace/validation.py index 008d98ce8..f72c8277f 100644 --- a/baybe/searchspace/validation.py +++ b/baybe/searchspace/validation.py @@ -1,10 +1,17 @@ """Validation functionality for search spaces.""" -from collections.abc import Collection +import warnings +from collections.abc import Collection, Sequence +from typing import TypeVar + +import pandas as pd from baybe.exceptions import EmptySearchSpaceError from baybe.parameters import TaskParameter from baybe.parameters.base import Parameter +from baybe.utils.dataframe import get_transform_objects + +_T = TypeVar("_T", bound=Parameter) def validate_parameter_names( # noqa: DOC101, DOC103 @@ -39,3 +46,20 @@ def validate_parameters(parameters: Collection[Parameter]) -> None: # noqa: DOC # Assert: unique names validate_parameter_names(parameters) + + +def get_transform_parameters( + parameters: Sequence[_T], + df: pd.DataFrame, + allow_missing: bool = False, + allow_extra: bool = False, +) -> list[_T]: + """Deprecated!""" # noqa: D401 + warnings.warn( + f"The function 'get_transform_parameters' has been deprecated and will be " + f"removed in a future version. Use '{get_transform_objects.__name__}' instead.", + DeprecationWarning, + ) + return get_transform_objects( + df, parameters, allow_missing=allow_missing, allow_extra=allow_extra + ) diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index 3cbed3809..01e9352e3 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -23,6 +23,7 @@ SequentialGreedyRecommender, ) from baybe.searchspace.continuous import SubspaceContinuous +from baybe.searchspace.validation import get_transform_parameters from baybe.targets.numerical import NumericalTarget @@ -215,3 +216,11 @@ def test_objective_transform_interface(): single.transform(data=pd.DataFrame(columns=["A"]), allow_extra=True) with pytest.warns(DeprecationWarning): desirability.transform(data=pd.DataFrame(columns=["A", "B"]), allow_extra=True) + + +def test_deprecated_get_transform_parameters(): + """Using the deprecated utility raises a warning.""" + with pytest.warns( + DeprecationWarning, match="'get_transform_parameters' has been deprecated" + ): + get_transform_parameters(pd.DataFrame(), []) From e59276213fd521ce801c75cf3280ea489663c93d Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 14 Oct 2024 14:44:35 +0200 Subject: [PATCH 12/13] Remove unnecessary parentheses --- baybe/objectives/base.py | 4 ++-- baybe/searchspace/core.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/baybe/objectives/base.py b/baybe/objectives/base.py index e155db26d..7d877cd00 100644 --- a/baybe/objectives/base.py +++ b/baybe/objectives/base.py @@ -43,10 +43,10 @@ def transform( df: The dataframe to be transformed. The allowed columns of the dataframe are dictated by the ``allow_missing`` and ``allow_extra`` flags. allow_missing: If ``False``, each target of the objective must have - (exactly) one corresponding column in the given dataframe. If ``True``, + exactly one corresponding column in the given dataframe. If ``True``, the dataframe may contain only a subset of target columns. allow_extra: If ``False``, every column present in the dataframe must - correspond to (exactly) one target of the objective. If ``True``, the + correspond to exactly one target of the objective. If ``True``, the dataframe may contain additional non-target-related columns, which will be ignored. diff --git a/baybe/searchspace/core.py b/baybe/searchspace/core.py index d8e363d49..0014cc9a4 100644 --- a/baybe/searchspace/core.py +++ b/baybe/searchspace/core.py @@ -358,10 +358,10 @@ def transform( The ``None`` default value is for temporary backward compatibility only and will be removed in a future version. allow_missing: If ``False``, each parameter of the space must have - (exactly) one corresponding column in the given dataframe. If ``True``, + exactly one corresponding column in the given dataframe. If ``True``, the dataframe may contain only a subset of parameter columns. allow_extra: If ``False``, every column present in the dataframe must - correspond to (exactly) one parameter of the space. If ``True``, the + correspond to exactly one parameter of the space. If ``True``, the dataframe may contain additional non-parameter-related columns, which will be ignored. The ``None`` default value is for temporary backward compatibility only From a76e008ff0ce8e6df62bd452feec7f30f438ef31 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 14 Oct 2024 14:54:30 +0200 Subject: [PATCH 13/13] Fix docstrings --- baybe/objectives/base.py | 2 +- baybe/searchspace/core.py | 2 +- baybe/utils/dataframe.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/baybe/objectives/base.py b/baybe/objectives/base.py index 7d877cd00..375bb7b55 100644 --- a/baybe/objectives/base.py +++ b/baybe/objectives/base.py @@ -45,7 +45,7 @@ def transform( allow_missing: If ``False``, each target of the objective must have exactly one corresponding column in the given dataframe. If ``True``, the dataframe may contain only a subset of target columns. - allow_extra: If ``False``, every column present in the dataframe must + allow_extra: If ``False``, each column present in the dataframe must correspond to exactly one target of the objective. If ``True``, the dataframe may contain additional non-target-related columns, which will be ignored. diff --git a/baybe/searchspace/core.py b/baybe/searchspace/core.py index 0014cc9a4..b307e3243 100644 --- a/baybe/searchspace/core.py +++ b/baybe/searchspace/core.py @@ -360,7 +360,7 @@ def transform( allow_missing: If ``False``, each parameter of the space must have exactly one corresponding column in the given dataframe. If ``True``, the dataframe may contain only a subset of parameter columns. - allow_extra: If ``False``, every column present in the dataframe must + allow_extra: If ``False``, each column present in the dataframe must correspond to exactly one parameter of the space. If ``True``, the dataframe may contain additional non-parameter-related columns, which will be ignored. diff --git a/baybe/utils/dataframe.py b/baybe/utils/dataframe.py index 15f2b638e..c0e776e6c 100644 --- a/baybe/utils/dataframe.py +++ b/baybe/utils/dataframe.py @@ -514,7 +514,7 @@ def get_transform_objects( ) -> list[_T]: """Extract the objects relevant for transforming a given dataframe. - The passed object are assumed to have corresponding columns in the given dataframe, + The passed objects are assumed to have corresponding columns in the given dataframe, identified through their name attribute. The function returns the subset of objects that have a corresponding column in the dataframe and thus provide the necessary information for transforming the dataframe.