-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Draft FilteredSubspaceDiscrete class
- Loading branch information
1 parent
d9448f8
commit c95fc18
Showing
2 changed files
with
35 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,54 +1,45 @@ | ||
"""Search spaces with metadata.""" | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
import pandas as pd | ||
from attrs import asdict, define, field | ||
from attrs import asdict, cmp_using, define, field | ||
from attrs.validators import instance_of | ||
from typing_extensions import Self, override | ||
|
||
from baybe.searchspace import SubspaceDiscrete | ||
from baybe.utils.boolean import eq_dataframe | ||
|
||
|
||
@define | ||
class AnnotatedSubspaceDiscrete(SubspaceDiscrete): | ||
"""An annotated search space carrying additional metadata.""" | ||
|
||
metadata: pd.DataFrame = field(kw_only=True, eq=eq_dataframe) | ||
"""The metadata.""" | ||
|
||
allow_repeated_recommendations: bool = field(kw_only=True) | ||
"""Allow to make recommendations that were already recommended earlier. | ||
This only has an influence in discrete search spaces.""" | ||
|
||
allow_recommending_already_measured: bool = field(kw_only=True) | ||
"""Allow to make recommendations that were measured previously. | ||
This only has an influence in discrete search spaces.""" | ||
class FilteredSubspaceDiscrete(SubspaceDiscrete): | ||
"""A filtered search space representing a reduced candidate set.""" | ||
|
||
mask: npt.NDArray[np.bool_] = field( | ||
validator=instance_of(np.ndarray), | ||
kw_only=True, | ||
eq=cmp_using(eq=np.array_equal), | ||
) | ||
"""The filtering mask. ``True`` denote elements to be kept.""" | ||
|
||
@mask.validator | ||
def _validate_mask(self, _, value) -> None: | ||
if not len(value) == len(self.exp_rep): | ||
raise ValueError("Filter mask must match the size of the space.") | ||
if not value.dtype == bool: | ||
raise ValueError("Filter mask must only contain Boolean values.") | ||
|
||
@classmethod | ||
def from_subspace( | ||
cls, | ||
subspace: SubspaceDiscrete, | ||
*, | ||
metadata: pd.DataFrame, | ||
allow_repeated_recommendations: bool, | ||
allow_recommending_already_measured: bool, | ||
cls, subspace: SubspaceDiscrete, mask: npt.NDArray[np.bool_] | ||
) -> Self: | ||
"""Annotate an existing subspace with metadata.""" | ||
"""Filter an existing subspace.""" | ||
return cls( | ||
**asdict(subspace, filter=lambda attr, _: attr.init, recurse=False), | ||
metadata=metadata, | ||
allow_repeated_recommendations=allow_repeated_recommendations, | ||
allow_recommending_already_measured=allow_recommending_already_measured, | ||
mask=mask, | ||
) | ||
|
||
@override | ||
def get_candidates(self) -> tuple[pd.DataFrame, pd.DataFrame]: | ||
from baybe.campaign import _EXCLUDED, _MEASURED, _RECOMMENDED | ||
|
||
# Exclude parts marked by metadata | ||
mask_todrop = self.metadata[_EXCLUDED].copy() | ||
if not self.allow_repeated_recommendations: | ||
mask_todrop |= self.metadata[_RECOMMENDED] | ||
if not self.allow_recommending_already_measured: | ||
mask_todrop |= self.metadata[_MEASURED] | ||
|
||
mask_todrop = self._excluded.copy() | ||
mask_todrop |= ~self.mask | ||
return self.exp_rep.loc[~mask_todrop], self.comp_rep.loc[~mask_todrop] |