diff --git a/openvino_xai/methods/black_box/aise/classification.py b/openvino_xai/methods/black_box/aise/classification.py index 513aed18..7c1947d1 100644 --- a/openvino_xai/methods/black_box/aise/classification.py +++ b/openvino_xai/methods/black_box/aise/classification.py @@ -57,6 +57,8 @@ def __init__( prepare_model=prepare_model, ) self.bounds = Bounds([0.0, 0.0], [1.0, 1.0]) + self.num_iterations_per_kernel: int | None = None + self.kernel_widths: List[float] | np.ndarray | None = None def generate_saliency_map( # type: ignore self, diff --git a/openvino_xai/methods/black_box/aise/detection.py b/openvino_xai/methods/black_box/aise/detection.py index 1cd4b9a2..32c1f5ed 100644 --- a/openvino_xai/methods/black_box/aise/detection.py +++ b/openvino_xai/methods/black_box/aise/detection.py @@ -57,6 +57,8 @@ def __init__( ) self.deletion = False self.predictions = {} + self.num_iterations_per_kernel: int | None = None + self.divisors: List[float] | np.ndarray | None = None def generate_saliency_map( # type: ignore self, diff --git a/openvino_xai/methods/black_box/rise.py b/openvino_xai/methods/black_box/rise.py index e14c1f35..57b33bc2 100644 --- a/openvino_xai/methods/black_box/rise.py +++ b/openvino_xai/methods/black_box/rise.py @@ -45,6 +45,8 @@ def __init__( super().__init__( model=model, postprocess_fn=postprocess_fn, preprocess_fn=preprocess_fn, device_name=device_name ) + self.num_masks: int | None = None + self.num_cells: int | None = None if prepare_model: self.prepare_model() @@ -84,13 +86,11 @@ def generate_saliency_map( """ data_preprocessed = self.preprocess_fn(data) - num_masks, num_cells = self._preset_parameters(preset, num_masks, num_cells) + self.num_masks, self.num_cells = self._preset_parameters(preset, num_masks, num_cells) saliency_maps = self._run_synchronous_explanation( data_preprocessed, target_indices, - num_masks, - num_cells, prob, seed, ) @@ -134,8 +134,6 @@ def _run_synchronous_explanation( self, data_preprocessed: np.ndarray, target_classes: List[int] | None, - num_masks: int, - num_cells: int, prob: float, seed: int, ) -> np.ndarray: @@ -152,8 +150,8 @@ def _run_synchronous_explanation( rand_generator = np.random.default_rng(seed=seed) saliency_maps = np.zeros((num_targets, input_size[0], input_size[1])) - for _ in tqdm(range(0, num_masks), desc="Explaining in synchronous mode"): - mask = self._generate_mask(input_size, num_cells, prob, rand_generator) + for _ in tqdm(range(0, self.num_masks), desc="Explaining in synchronous mode"): + mask = self._generate_mask(input_size, self.num_cells, prob, rand_generator) # Add channel dimensions for masks if is_bhwc_layout(data_preprocessed): masked = np.expand_dims(mask, 2) * data_preprocessed diff --git a/tests/unit/methods/black_box/test_black_box_method.py b/tests/unit/methods/black_box/test_black_box_method.py index cd42a539..0cee1c51 100644 --- a/tests/unit/methods/black_box/test_black_box_method.py +++ b/tests/unit/methods/black_box/test_black_box_method.py @@ -104,16 +104,22 @@ def test_preset(self, fxt_data_root: Path): self._generate_with_preset(method, Preset.SPEED) toc = time.time() time_speed = toc - tic + assert method.num_iterations_per_kernel == 20 + assert np.all(method.kernel_widths == np.array([0.1, 0.175, 0.25])) tic = time.time() self._generate_with_preset(method, Preset.BALANCE) toc = time.time() time_balance = toc - tic + assert method.num_iterations_per_kernel == 50 + assert np.all(method.kernel_widths == np.array([0.1, 0.175, 0.25])) tic = time.time() self._generate_with_preset(method, Preset.QUALITY) toc = time.time() time_quality = toc - tic + assert method.num_iterations_per_kernel == 50 + np.testing.assert_allclose(method.kernel_widths, np.array([0.075, 0.11875, 0.1625, 0.20625, 0.25])) assert time_speed < time_balance < time_quality @@ -171,16 +177,22 @@ def test_preset(self, fxt_data_root: Path): self._generate_with_preset(method, Preset.SPEED) toc = time.time() time_speed = toc - tic + assert method.num_iterations_per_kernel == 20 + assert np.all(method.divisors == np.array([7.0, 4.0, 1.0])) tic = time.time() self._generate_with_preset(method, Preset.BALANCE) toc = time.time() time_balance = toc - tic + assert method.num_iterations_per_kernel == 50 + assert np.all(method.divisors == np.array([7.0, 4.0, 1.0])) tic = time.time() self._generate_with_preset(method, Preset.QUALITY) toc = time.time() time_quality = toc - tic + assert method.num_iterations_per_kernel == 50 + assert np.all(method.divisors == np.array([8.0, 6.25, 4.5, 2.75, 1.0])) assert time_speed < time_balance < time_quality @@ -227,16 +239,22 @@ def test_preset(self, fxt_data_root: Path): self._generate_with_preset(method, Preset.SPEED) toc = time.time() time_speed = toc - tic + assert method.num_masks == 1000 + assert method.num_cells == 4 tic = time.time() self._generate_with_preset(method, Preset.BALANCE) toc = time.time() time_balance = toc - tic + assert method.num_masks == 5000 + assert method.num_cells == 8 tic = time.time() self._generate_with_preset(method, Preset.QUALITY) toc = time.time() time_quality = toc - tic + assert method.num_masks == 10_000 + assert method.num_cells == 12 assert time_speed < time_balance < time_quality