Skip to content

Commit

Permalink
Refactor filtering test (#287)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfmig authored Aug 29, 2024
1 parent 9612951 commit f47d6e7
Showing 1 changed file with 16 additions and 22 deletions.
38 changes: 16 additions & 22 deletions tests/test_unit/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,25 +157,19 @@ def test_filter_on_position(
assert not position_filtered.equals(valid_input_dataset.position)


# Expected number of nans in the position array per individual,
# for each dataset
expected_n_nans_in_position_per_indiv = {
"valid_poses_dataset": {0: 0, 1: 0},
# filtering should not introduce nans if input has no nans
"valid_bboxes_dataset": {0: 0, 1: 0},
# filtering should not introduce nans if input has no nans
"valid_poses_dataset_with_nan": {0: 7, 1: 0},
# individual with index 0 has 7 frames with nans in position after
# filtering individual with index 1 has no nans after filtering
"valid_bboxes_dataset_with_nan": {0: 7, 1: 0},
# individual with index 0 has 7 frames with nans in position after
# filtering individual with index 0 has no nans after filtering
}


# Expected number of nans in the position array per
# individual, after applying a filter with window size 3
@pytest.mark.parametrize(
("valid_dataset, expected_n_nans_in_position_per_indiv"),
[(k, v) for k, v in expected_n_nans_in_position_per_indiv.items()],
("valid_dataset, expected_nans_in_filtered_position_per_indiv"),
[
(
"valid_poses_dataset",
{0: 0, 1: 0},
), # filtering should not introduce nans if input has no nans
("valid_bboxes_dataset", {0: 0, 1: 0}),
("valid_poses_dataset_with_nan", {0: 7, 1: 0}),
("valid_bboxes_dataset_with_nan", {0: 7, 1: 0}),
],
)
@pytest.mark.parametrize(
("filter_func, filter_kwargs"),
Expand All @@ -188,7 +182,7 @@ def test_filter_with_nans_on_position(
filter_func,
filter_kwargs,
valid_dataset,
expected_n_nans_in_position_per_indiv,
expected_nans_in_filtered_position_per_indiv,
helpers,
request,
):
Expand All @@ -199,7 +193,7 @@ def test_filter_with_nans_on_position(
def _assert_n_nans_in_position_per_individual(
valid_input_dataset,
position_filtered,
expected_n_nans_in_position_per_indiv,
expected_nans_in_filt_position_per_indiv,
):
# compute n nans in position after filtering per individual
n_nans_after_filtering_per_indiv = {
Expand All @@ -210,7 +204,7 @@ def _assert_n_nans_in_position_per_individual(
# check number of nans per indiv is as expected
for i in range(valid_input_dataset.dims["individuals"]):
assert n_nans_after_filtering_per_indiv[i] == (
expected_n_nans_in_position_per_indiv[i]
expected_nans_in_filt_position_per_indiv[i]
* valid_input_dataset.dims["space"]
* valid_input_dataset.dims.get("keypoints", 1)
)
Expand All @@ -225,7 +219,7 @@ def _assert_n_nans_in_position_per_individual(
_assert_n_nans_in_position_per_individual(
valid_input_dataset,
position_filtered,
expected_n_nans_in_position_per_indiv,
expected_nans_in_filtered_position_per_indiv,
)

# if input had nans,
Expand Down

0 comments on commit f47d6e7

Please sign in to comment.