From f47d6e7ba9cb5af67c16fac6fb0d0ae4c550f100 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 29 Aug 2024 11:02:03 +0100 Subject: [PATCH] Refactor filtering test (#287) --- tests/test_unit/test_filtering.py | 38 +++++++++++++------------------ 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/tests/test_unit/test_filtering.py b/tests/test_unit/test_filtering.py index 1957a770..1bc59348 100644 --- a/tests/test_unit/test_filtering.py +++ b/tests/test_unit/test_filtering.py @@ -157,25 +157,19 @@ def test_filter_on_position( assert not position_filtered.equals(valid_input_dataset.position) -# Expected number of nans in the position array per individual, -# for each dataset -expected_n_nans_in_position_per_indiv = { - "valid_poses_dataset": {0: 0, 1: 0}, - # filtering should not introduce nans if input has no nans - "valid_bboxes_dataset": {0: 0, 1: 0}, - # filtering should not introduce nans if input has no nans - "valid_poses_dataset_with_nan": {0: 7, 1: 0}, - # individual with index 0 has 7 frames with nans in position after - # filtering individual with index 1 has no nans after filtering - "valid_bboxes_dataset_with_nan": {0: 7, 1: 0}, - # individual with index 0 has 7 frames with nans in position after - # filtering individual with index 0 has no nans after filtering -} - - +# Expected number of nans in the position array per +# individual, after applying a filter with window size 3 @pytest.mark.parametrize( - ("valid_dataset, expected_n_nans_in_position_per_indiv"), - [(k, v) for k, v in expected_n_nans_in_position_per_indiv.items()], + ("valid_dataset, expected_nans_in_filtered_position_per_indiv"), + [ + ( + "valid_poses_dataset", + {0: 0, 1: 0}, + ), # filtering should not introduce nans if input has no nans + ("valid_bboxes_dataset", {0: 0, 1: 0}), + ("valid_poses_dataset_with_nan", {0: 7, 1: 0}), + ("valid_bboxes_dataset_with_nan", {0: 7, 1: 0}), + ], ) @pytest.mark.parametrize( ("filter_func, filter_kwargs"), @@ -188,7 +182,7 @@ def test_filter_with_nans_on_position( filter_func, filter_kwargs, valid_dataset, - expected_n_nans_in_position_per_indiv, + expected_nans_in_filtered_position_per_indiv, helpers, request, ): @@ -199,7 +193,7 @@ def test_filter_with_nans_on_position( def _assert_n_nans_in_position_per_individual( valid_input_dataset, position_filtered, - expected_n_nans_in_position_per_indiv, + expected_nans_in_filt_position_per_indiv, ): # compute n nans in position after filtering per individual n_nans_after_filtering_per_indiv = { @@ -210,7 +204,7 @@ def _assert_n_nans_in_position_per_individual( # check number of nans per indiv is as expected for i in range(valid_input_dataset.dims["individuals"]): assert n_nans_after_filtering_per_indiv[i] == ( - expected_n_nans_in_position_per_indiv[i] + expected_nans_in_filt_position_per_indiv[i] * valid_input_dataset.dims["space"] * valid_input_dataset.dims.get("keypoints", 1) ) @@ -225,7 +219,7 @@ def _assert_n_nans_in_position_per_individual( _assert_n_nans_in_position_per_individual( valid_input_dataset, position_filtered, - expected_n_nans_in_position_per_indiv, + expected_nans_in_filtered_position_per_indiv, ) # if input had nans,