From 62bdb468db3291cd733c8d7d7091d238c54a3a8d Mon Sep 17 00:00:00 2001 From: Martin Fitzner Date: Tue, 2 Jul 2024 19:27:16 +0200 Subject: [PATCH] Simplify permutation --- baybe/utils/augmentation.py | 112 ++++++++++++++++-------------------- tests/test_utils.py | 55 ++++++------------ 2 files changed, 67 insertions(+), 100 deletions(-) diff --git a/baybe/utils/augmentation.py b/baybe/utils/augmentation.py index 8cd886c5d..b4b6c5e6e 100644 --- a/baybe/utils/augmentation.py +++ b/baybe/utils/augmentation.py @@ -33,58 +33,53 @@ def _row_in_df(row: pd.Series | pd.DataFrame, df: pd.DataFrame) -> bool: def df_apply_permutation_augmentation( df: pd.DataFrame, - columns: Sequence[str], - dependents: Sequence[Sequence[str]] | None = None, + columns: Sequence[Sequence[str]], ) -> pd.DataFrame: """Augment a dataframe if permutation invariant columns are present. - Indices are preserved so that each augmented row will have the same index as its - original. ``dependent`` columns are augmented in the same order as the ``columns``. - * Original - +---+---+---+---+ - | A | B | C | D | - +===+===+===+===+ - | a | b | x | y | - +---+---+---+---+ - | b | a | x | z | - +---+---+---+---+ - - * Result with ``columns = ["A", "B"]`` - - +---+---+---+---+ - | A | B | C | D | - +===+===+===+===+ - | a | b | x | y | - +---+---+---+---+ - | b | a | x | z | - +---+---+---+---+ - | b | a | x | y | - +---+---+---+---+ - | a | b | x | z | - +---+---+---+---+ - - * Result with ``columns = ["A", "B"]``, ``dependents = [["C"], ["D"]]`` - - +---+---+---+---+ - | A | B | C | D | - +===+===+===+===+ - | a | b | x | y | - +---+---+---+---+ - | b | a | x | z | - +---+---+---+---+ - | b | a | y | x | - +---+---+---+---+ - | a | b | z | x | - +---+---+---+---+ + +----+----+----+----+ + | A1 | A2 | B1 | B2 | + +====+====+====+====+ + | a | b | x | y | + +----+----+----+----+ + | b | a | x | z | + +----+----+----+----+ + + * Result with ``columns = [["A1"], ["A2"]]`` + + +----+----+----+----+ + | A1 | A2 | B1 | B2 | + +====+====+====+====+ + | a | b | x | y | + +----+----+----+----+ + | b | a | x | z | + +----+----+----+----+ + | b | a | x | y | + +----+----+----+----+ + | a | b | x | z | + +----+----+----+----+ + + * Result with ``columns = [["A1", "B1"], ["A2", "B2"]]`` + + +----+----+----+----+ + | A1 | A2 | B1 | B2 | + +====+====+====+====+ + | a | b | x | y | + +----+----+----+----+ + | b | a | x | z | + +----+----+----+----+ + | b | a | y | x | + +----+----+----+----+ + | a | b | z | x | + +----+----+----+----+ Args: df: The dataframe that should be augmented. - columns: The permutation invariant columns. - dependents: Columns that are connected to ``columns`` and should be permuted in - the same manner. Can be multiple per entry in ``affected`` but all must be - of same length. + columns: Sequences of permutation invariant columns. The n'th column in each + sequence will be permuted together with each n'th column in the other + sequences. Returns: The augmented dataframe containing the original one. @@ -94,20 +89,16 @@ def df_apply_permutation_augmentation( ValueError: If entries in ``dependents`` are not of same length. """ # Validation - dependents = dependents or [] - if dependents: - if len(columns) != len(dependents): - raise ValueError( - "When augmenting permutation invariance with dependent columns, " - "'dependents' must have exactly as many entries as 'columns'." - ) - if len({len(d) for d in dependents}) != 1 or len(dependents[0]) < 1: - raise ValueError( - "Augmentation with dependents can only work if the amount of dependent " - "columns provided as entries of 'dependents' is the same for all " - "affected columns. If there are no dependents, set 'dependents' to " - "None." - ) + if len(columns) < 2: + raise ValueError( + "When augmenting permutation invariance, at least two column sequences " + "must be given." + ) + if len({len(seq) for seq in columns}) != 1 or len(columns[0]) < 1: + raise ValueError( + "Permutation augmentation can only work if the amount of columns un each " + "sequence is the same and the sequences are not empty." + ) # Augmentation Loop new_rows: list[pd.DataFrame] = [] @@ -118,10 +109,7 @@ def df_apply_permutation_augmentation( new_row = row.copy() # Permute columns - new_row[columns] = row[[columns[k] for k in perm]] - - # Permute dependent columns - for deps in map(list, zip(*dependents)): + for deps in map(list, zip(*columns)): new_row[deps] = row[[deps[k] for k in perm]] # Check whether the new row is an existing permutation diff --git a/tests/test_utils.py b/tests/test_utils.py index f16b44ab6..a2bc6612e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -127,7 +127,7 @@ def test_invalid_register_hooks(target, hook): @pytest.mark.parametrize( - ("data", "columns", "dependents", "data_expected"), + ("data", "columns", "data_expected"), [ param( # 2 invariant cols and 1 unaffected col { @@ -135,8 +135,7 @@ def test_invalid_register_hooks(target, hook): "B": [2, 2], "C": ["x", "y"], }, - ["A", "B"], - None, + [["A"], ["B"]], { "A": [1, 2, 1, 2], "B": [2, 1, 2, 1], @@ -149,8 +148,7 @@ def test_invalid_register_hooks(target, hook): "A": [1, 1], "B": [2, 2], }, - ["A", "B"], - None, + [["A"], ["B"]], { "A": [1, 1, 2], "B": [2, 2, 1], @@ -163,8 +161,7 @@ def test_invalid_register_hooks(target, hook): "B": [2, 2], "T": ["x", "y"], }, - ["A", "B"], - None, + [["A"], ["B"]], { "A": [1, 1, 2, 2], "B": [2, 2, 1, 1], @@ -178,8 +175,7 @@ def test_invalid_register_hooks(target, hook): "B": [2, 2], "T": ["x", "x"], }, - ["A", "B"], - None, + [["A"], ["B"]], { "A": [1, 2], "B": [2, 1], @@ -187,21 +183,6 @@ def test_invalid_register_hooks(target, hook): }, id="2inv+degen_target+degen", ), - param( # 3 invariant cols - { - "A": [1, 1], - "B": [2, 4], - "C": [3, 5], - }, - ["A", "B", "C"], - None, - { - "A": [1, 1, 2, 2, 3, 3, 1, 1, 4, 4, 5, 5], - "B": [2, 3, 1, 3, 2, 1, 4, 5, 1, 5, 1, 4], - "C": [3, 2, 3, 1, 1, 2, 5, 4, 5, 1, 4, 1], - }, - id="3inv", - ), param( # 3 invariant cols { "A": [1, 1], @@ -209,8 +190,7 @@ def test_invalid_register_hooks(target, hook): "C": [3, 5], "D": ["x", "y"], }, - ["A", "B", "C"], - None, + [["A"], ["B"], ["C"]], { "A": [1, 1, 2, 2, 3, 3, 1, 1, 4, 4, 5, 5], "B": [2, 3, 1, 3, 2, 1, 4, 5, 1, 5, 1, 4], @@ -228,8 +208,7 @@ def test_invalid_register_hooks(target, hook): "Other1": ["A", "B"], "Other2": ["C", "D"], }, - ["Slot1", "Slot2"], - [["Frac1"], ["Frac2"]], + [["Slot1", "Frac1"], ["Slot2", "Frac2"]], { "Slot1": ["s1", "s2", "s2", "s4"], "Slot2": ["s2", "s4", "s1", "s2"], @@ -250,8 +229,7 @@ def test_invalid_register_hooks(target, hook): "Temp2": [50, 60], "Other": ["x", "y"], }, - ["Slot1", "Slot2"], - [["Frac1", "Temp1"], ["Frac2", "Temp2"]], + [["Slot1", "Frac1", "Temp1"], ["Slot2", "Frac2", "Temp2"]], { "Slot1": ["s1", "s2", "s2", "s4"], "Slot2": ["s2", "s4", "s1", "s2"], @@ -265,11 +243,11 @@ def test_invalid_register_hooks(target, hook): ), ], ) -def test_df_permutation_aug(data, columns, dependents, data_expected): +def test_df_permutation_aug(data, columns, data_expected): """Test permutation invariance data augmentation is done correctly.""" # Create all needed dataframes df = pd.DataFrame(data) - df_augmented = df_apply_permutation_augmentation(df, columns, dependents) + df_augmented = df_apply_permutation_augmentation(df, columns) df_expected = pd.DataFrame(data_expected) # Determine equality ignoring row order @@ -287,18 +265,19 @@ def test_df_permutation_aug(data, columns, dependents, data_expected): @pytest.mark.parametrize( - ("columns", "dependents", "msg"), + ("columns", "msg"), [ - param(["A"], [["B"], ["C"]], "exactly as many", id="too_manydependents"), - param(["A", "B"], [[], []], "same for all", id="dep_length_zero"), - param(["A", "B"], [["C"], []], "same for all", id="different_dep_lengths"), + param([], "at least two column sequences", id="no_seqs"), + param([["A"]], "at least two column sequences", id="just_one_seq"), + param([["A"], ["B", "C"]], "sequence is the same", id="different_lengths"), + param([[], []], "sequence is the same", id="empty_seqs"), ], ) -def test_df_permutation_aug_invalid(columns, dependents, msg): +def test_df_permutation_aug_invalid(columns, msg): """Test correct errors for invalid permutation attempts.""" df = pd.DataFrame({"A": [1, 1], "B": [2, 2], "C": ["x", "y"]}) with pytest.raises(ValueError, match=msg): - df_apply_permutation_augmentation(df, columns, dependents) + df_apply_permutation_augmentation(df, columns) @pytest.mark.parametrize(