Skip to content

Commit

Permalink
Extract filter_df function
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianSosic committed Nov 18, 2024
1 parent 23cc131 commit 0987d10
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 47 deletions.
27 changes: 7 additions & 20 deletions baybe/campaign.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
telemetry_record_value,
)
from baybe.utils.boolean import eq_dataframe
from baybe.utils.dataframe import fuzzy_row_match
from baybe.utils.dataframe import filter_df, fuzzy_row_match
from baybe.utils.plotting import to_string

if TYPE_CHECKING:
Expand Down Expand Up @@ -309,34 +309,21 @@ def toggle_discrete_candidates(
Args:
filter: A dataframe specifying the filtering mechanism to determine the
candidates subset to be in-/excluded. The subset is determined via a
join (see ``anti`` argument for details) with the discrete space.
candidates subset to be in-/excluded. For details, see
:func:`baybe.utils.dataframe.filter_df`.
exclude: If ``True``, the specified candidates are excluded.
If ``False``, the candidates are considered for recommendation.
anti: If ``False``, the filter determines the points to be affected (i.e.
selection via regular join). If ``True``, the filtering mechanism is
inverted in that only the points passing the filter are unaffected (i.e.
selection via anti-join).
anti: Boolean flag deciding if the points specified by the filter or their
complement is to be kept. For details, see
:func:`baybe.utils.dataframe.filter_df`.
dry_run: If ``True``, the target subset is only extracted but not
affected. If ``False``, the candidate set is updated correspondingly.
Useful for setting up the correct filtering mechanism.
Returns:
The discrete candidate set passing through the specified filter.
"""
exp_rep = self.searchspace.discrete.exp_rep
index_name = exp_rep.index.name

# Identify points to be dropped
points = pd.merge(
exp_rep.reset_index(names="_df_index"), filter, how="left", indicator=True
).set_index("_df_index")
to_drop = points["_merge"] == ("both" if anti else "left_only")

# Drop the points
points.drop(index=points[to_drop].index, inplace=True)
points.drop("_merge", axis=1, inplace=True)
points.index.name = index_name
points = filter_df(self.searchspace.discrete.exp_rep, filter, anti)

if not dry_run:
self._searchspace_metadata.loc[points.index, _EXCLUDED] = exclude
Expand Down
61 changes: 61 additions & 0 deletions baybe/utils/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,3 +601,64 @@ def get_transform_objects(
)

return [p for p in objects if p.name in df]


def filter_df(
df: pd.DataFrame, filter: pd.DataFrame, anti: bool = False
) -> pd.DataFrame:
"""Filter a dataframe based on a second dataframe defining filtering conditions.
Filtering is done via a join (see ``anti`` argument for details) between the
input dataframe and the filter dataframe.
Args:
df: The dataframe to be filtered.
filter: The dataframe defining the filtering conditions.
anti: If ``False``, the filter dataframe determines the rows to be removed
(i.e. selection via regular join). If ``True``, the filtering mechanism is
inverted in that only the points passing the filter are kept (i.e. selection
via anti-join).
Returns:
A new dataframe containing the result of the filtering process.
Examples:
>>> df = pd.DataFrame(
[[0, "a"], [0, "b"], [1, "a"], [1, "b"]],
columns=["num", "cat"]
)
>>> df
num cat
0 0 a
1 0 b
2 1 a
3 1 b
>>> filter_df(df, pd.DataFrame([0], columns=["num"]), anti=False)
num cat
0 0 a
1 0 b
>>> filter_df(df, pd.DataFrame([0], columns=["num"]), anti=True)
num cat
2 1 a
3 1 b
"""
# Remember original index name
index_name = df.index.name

# Identify rows to be dropped
out = pd.merge(
df.reset_index(names="_df_index"), filter, how="left", indicator=True
).set_index("_df_index")
to_drop = out["_merge"] == ("both" if anti else "left_only")

# Drop the points
out.drop(index=out[to_drop].index, inplace=True)
out.drop("_merge", axis=1, inplace=True)

# Restore original index name
out.index.name = index_name

return out
36 changes: 9 additions & 27 deletions tests/test_campaign.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pandas as pd
import pytest
from pandas.testing import assert_frame_equal
from pytest import param

from baybe.campaign import _EXCLUDED, Campaign
Expand Down Expand Up @@ -37,25 +36,10 @@ def test_get_surrogate(campaign, n_iterations, batch_size):
assert model is not None, "Something went wrong during surrogate model extraction."


@pytest.mark.parametrize(
("anti", "expected"),
[
(
False,
pd.DataFrame(columns=["a", "b"], data=[[0.0, 3.0], [0.0, 4.0], [0.0, 5.0]]),
),
(
True,
pd.DataFrame(columns=["a", "b"], data=[[1.0, 3.0], [1.0, 4.0], [1.0, 5.0]]),
),
],
ids=["regular", "anti"],
)
@pytest.mark.parametrize("anti", [False, True], ids=["regular", "anti"])
@pytest.mark.parametrize("exclude", [True, False], ids=["exclude", "include"])
def test_candidate_filter(exclude, anti, expected):
"""The candidate filter extracts the correct subset of points and the campaign
metadata is updated accordingly.""" # noqa

def test_candidate_toggling(exclude, anti):
"""Toggling discrete candidates updates the campaign metadata accordingly."""
subspace = SubspaceDiscrete.from_product(
[
NumericalDiscreteParameter("a", [0, 1]),
Expand All @@ -68,16 +52,14 @@ def test_candidate_filter(exclude, anti, expected):
campaign._searchspace_metadata[_EXCLUDED] = not exclude

# Toggle the candidates
df = campaign.toggle_discrete_candidates(
pd.DataFrame({"a": [0]}), exclude, anti=anti
)
campaign.toggle_discrete_candidates(pd.DataFrame({"a": [0]}), exclude, anti=anti)

# Assert that the filtering is correct
rows = pd.merge(df.reset_index(), expected).set_index("index")
assert_frame_equal(df, rows, check_names=False)
# Extract row indices of candidates whose metadata should have been toggled
matches = campaign.searchspace.discrete.exp_rep["a"] == 0
idx = matches.index[~matches] if anti else matches.index[matches]

# Assert that metadata is set correctly
target = campaign._searchspace_metadata.loc[rows.index, _EXCLUDED]
other = campaign._searchspace_metadata[_EXCLUDED].drop(index=rows.index)
target = campaign._searchspace_metadata.loc[idx, _EXCLUDED]
other = campaign._searchspace_metadata[_EXCLUDED].drop(index=idx)
assert all(target == exclude) # must contain the updated values
assert all(other != exclude) # must contain the original values

0 comments on commit 0987d10

Please sign in to comment.