Skip to content

Commit

Permalink
Simplify permutation
Browse files Browse the repository at this point in the history
  • Loading branch information
Scienfitz committed Jul 3, 2024
1 parent 5d9d0d3 commit 62bdb46
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 100 deletions.
112 changes: 50 additions & 62 deletions baybe/utils/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,58 +33,53 @@ def _row_in_df(row: pd.Series | pd.DataFrame, df: pd.DataFrame) -> bool:

def df_apply_permutation_augmentation(
df: pd.DataFrame,
columns: Sequence[str],
dependents: Sequence[Sequence[str]] | None = None,
columns: Sequence[Sequence[str]],
) -> pd.DataFrame:
"""Augment a dataframe if permutation invariant columns are present.
Indices are preserved so that each augmented row will have the same index as its
original. ``dependent`` columns are augmented in the same order as the ``columns``.
* Original
+---+---+---+---+
| A | B | C | D |
+===+===+===+===+
| a | b | x | y |
+---+---+---+---+
| b | a | x | z |
+---+---+---+---+
* Result with ``columns = ["A", "B"]``
+---+---+---+---+
| A | B | C | D |
+===+===+===+===+
| a | b | x | y |
+---+---+---+---+
| b | a | x | z |
+---+---+---+---+
| b | a | x | y |
+---+---+---+---+
| a | b | x | z |
+---+---+---+---+
* Result with ``columns = ["A", "B"]``, ``dependents = [["C"], ["D"]]``
+---+---+---+---+
| A | B | C | D |
+===+===+===+===+
| a | b | x | y |
+---+---+---+---+
| b | a | x | z |
+---+---+---+---+
| b | a | y | x |
+---+---+---+---+
| a | b | z | x |
+---+---+---+---+
+----+----+----+----+
| A1 | A2 | B1 | B2 |
+====+====+====+====+
| a | b | x | y |
+----+----+----+----+
| b | a | x | z |
+----+----+----+----+
* Result with ``columns = [["A1"], ["A2"]]``
+----+----+----+----+
| A1 | A2 | B1 | B2 |
+====+====+====+====+
| a | b | x | y |
+----+----+----+----+
| b | a | x | z |
+----+----+----+----+
| b | a | x | y |
+----+----+----+----+
| a | b | x | z |
+----+----+----+----+
* Result with ``columns = [["A1", "B1"], ["A2", "B2"]]``
+----+----+----+----+
| A1 | A2 | B1 | B2 |
+====+====+====+====+
| a | b | x | y |
+----+----+----+----+
| b | a | x | z |
+----+----+----+----+
| b | a | y | x |
+----+----+----+----+
| a | b | z | x |
+----+----+----+----+
Args:
df: The dataframe that should be augmented.
columns: The permutation invariant columns.
dependents: Columns that are connected to ``columns`` and should be permuted in
the same manner. Can be multiple per entry in ``affected`` but all must be
of same length.
columns: Sequences of permutation invariant columns. The n'th column in each
sequence will be permuted together with each n'th column in the other
sequences.
Returns:
The augmented dataframe containing the original one.
Expand All @@ -94,20 +89,16 @@ def df_apply_permutation_augmentation(
ValueError: If entries in ``dependents`` are not of same length.
"""
# Validation
dependents = dependents or []
if dependents:
if len(columns) != len(dependents):
raise ValueError(
"When augmenting permutation invariance with dependent columns, "
"'dependents' must have exactly as many entries as 'columns'."
)
if len({len(d) for d in dependents}) != 1 or len(dependents[0]) < 1:
raise ValueError(
"Augmentation with dependents can only work if the amount of dependent "
"columns provided as entries of 'dependents' is the same for all "
"affected columns. If there are no dependents, set 'dependents' to "
"None."
)
if len(columns) < 2:
raise ValueError(
"When augmenting permutation invariance, at least two column sequences "
"must be given."
)
if len({len(seq) for seq in columns}) != 1 or len(columns[0]) < 1:
raise ValueError(
"Permutation augmentation can only work if the amount of columns un each "
"sequence is the same and the sequences are not empty."
)

# Augmentation Loop
new_rows: list[pd.DataFrame] = []
Expand All @@ -118,10 +109,7 @@ def df_apply_permutation_augmentation(
new_row = row.copy()

# Permute columns
new_row[columns] = row[[columns[k] for k in perm]]

# Permute dependent columns
for deps in map(list, zip(*dependents)):
for deps in map(list, zip(*columns)):
new_row[deps] = row[[deps[k] for k in perm]]

# Check whether the new row is an existing permutation
Expand Down
55 changes: 17 additions & 38 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,15 @@ def test_invalid_register_hooks(target, hook):


@pytest.mark.parametrize(
("data", "columns", "dependents", "data_expected"),
("data", "columns", "data_expected"),
[
param( # 2 invariant cols and 1 unaffected col
{
"A": [1, 1],
"B": [2, 2],
"C": ["x", "y"],
},
["A", "B"],
None,
[["A"], ["B"]],
{
"A": [1, 2, 1, 2],
"B": [2, 1, 2, 1],
Expand All @@ -149,8 +148,7 @@ def test_invalid_register_hooks(target, hook):
"A": [1, 1],
"B": [2, 2],
},
["A", "B"],
None,
[["A"], ["B"]],
{
"A": [1, 1, 2],
"B": [2, 2, 1],
Expand All @@ -163,8 +161,7 @@ def test_invalid_register_hooks(target, hook):
"B": [2, 2],
"T": ["x", "y"],
},
["A", "B"],
None,
[["A"], ["B"]],
{
"A": [1, 1, 2, 2],
"B": [2, 2, 1, 1],
Expand All @@ -178,39 +175,22 @@ def test_invalid_register_hooks(target, hook):
"B": [2, 2],
"T": ["x", "x"],
},
["A", "B"],
None,
[["A"], ["B"]],
{
"A": [1, 2],
"B": [2, 1],
"T": ["x", "x"],
},
id="2inv+degen_target+degen",
),
param( # 3 invariant cols
{
"A": [1, 1],
"B": [2, 4],
"C": [3, 5],
},
["A", "B", "C"],
None,
{
"A": [1, 1, 2, 2, 3, 3, 1, 1, 4, 4, 5, 5],
"B": [2, 3, 1, 3, 2, 1, 4, 5, 1, 5, 1, 4],
"C": [3, 2, 3, 1, 1, 2, 5, 4, 5, 1, 4, 1],
},
id="3inv",
),
param( # 3 invariant cols
{
"A": [1, 1],
"B": [2, 4],
"C": [3, 5],
"D": ["x", "y"],
},
["A", "B", "C"],
None,
[["A"], ["B"], ["C"]],
{
"A": [1, 1, 2, 2, 3, 3, 1, 1, 4, 4, 5, 5],
"B": [2, 3, 1, 3, 2, 1, 4, 5, 1, 5, 1, 4],
Expand All @@ -228,8 +208,7 @@ def test_invalid_register_hooks(target, hook):
"Other1": ["A", "B"],
"Other2": ["C", "D"],
},
["Slot1", "Slot2"],
[["Frac1"], ["Frac2"]],
[["Slot1", "Frac1"], ["Slot2", "Frac2"]],
{
"Slot1": ["s1", "s2", "s2", "s4"],
"Slot2": ["s2", "s4", "s1", "s2"],
Expand All @@ -250,8 +229,7 @@ def test_invalid_register_hooks(target, hook):
"Temp2": [50, 60],
"Other": ["x", "y"],
},
["Slot1", "Slot2"],
[["Frac1", "Temp1"], ["Frac2", "Temp2"]],
[["Slot1", "Frac1", "Temp1"], ["Slot2", "Frac2", "Temp2"]],
{
"Slot1": ["s1", "s2", "s2", "s4"],
"Slot2": ["s2", "s4", "s1", "s2"],
Expand All @@ -265,11 +243,11 @@ def test_invalid_register_hooks(target, hook):
),
],
)
def test_df_permutation_aug(data, columns, dependents, data_expected):
def test_df_permutation_aug(data, columns, data_expected):
"""Test permutation invariance data augmentation is done correctly."""
# Create all needed dataframes
df = pd.DataFrame(data)
df_augmented = df_apply_permutation_augmentation(df, columns, dependents)
df_augmented = df_apply_permutation_augmentation(df, columns)
df_expected = pd.DataFrame(data_expected)

# Determine equality ignoring row order
Expand All @@ -287,18 +265,19 @@ def test_df_permutation_aug(data, columns, dependents, data_expected):


@pytest.mark.parametrize(
("columns", "dependents", "msg"),
("columns", "msg"),
[
param(["A"], [["B"], ["C"]], "exactly as many", id="too_manydependents"),
param(["A", "B"], [[], []], "same for all", id="dep_length_zero"),
param(["A", "B"], [["C"], []], "same for all", id="different_dep_lengths"),
param([], "at least two column sequences", id="no_seqs"),
param([["A"]], "at least two column sequences", id="just_one_seq"),
param([["A"], ["B", "C"]], "sequence is the same", id="different_lengths"),
param([[], []], "sequence is the same", id="empty_seqs"),
],
)
def test_df_permutation_aug_invalid(columns, dependents, msg):
def test_df_permutation_aug_invalid(columns, msg):
"""Test correct errors for invalid permutation attempts."""
df = pd.DataFrame({"A": [1, 1], "B": [2, 2], "C": ["x", "y"]})
with pytest.raises(ValueError, match=msg):
df_apply_permutation_augmentation(df, columns, dependents)
df_apply_permutation_augmentation(df, columns)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 62bdb46

Please sign in to comment.