diff --git a/openvino_xai/explainer/explainer.py b/openvino_xai/explainer/explainer.py index a27fc74b..9643ae71 100644 --- a/openvino_xai/explainer/explainer.py +++ b/openvino_xai/explainer/explainer.py @@ -34,6 +34,7 @@ class ExplainMode(Enum): Contains the following values: WHITEBOX - The model is explained in white box mode, i.e. XAI branch is getting inserted into the model graph. BLACKBOX - The model is explained in black box model. + AUTO - The model is explained in the white-box mode first, if fails - black-box mode will run. """ WHITEBOX = "whitebox" diff --git a/openvino_xai/explainer/explanation.py b/openvino_xai/explainer/explanation.py index 28c9bc42..718a0e63 100644 --- a/openvino_xai/explainer/explanation.py +++ b/openvino_xai/explainer/explanation.py @@ -188,7 +188,7 @@ def save( map_to_save = cv2.cvtColor(map_to_save, code=cv2.COLOR_RGB2BGR) if isinstance(target_idx, str): target_name = "activation_map" - elif self.label_names and isinstance(target_idx, np.int64) and self.task != Task.DETECTION: + elif self.label_names and isinstance(target_idx, (int, np.int64)) and self.task != Task.DETECTION: target_name = self.label_names[target_idx] else: target_name = str(target_idx) @@ -261,7 +261,12 @@ def _plot_matplotlib(self, checked_targets: list[int | str], num_cols: int) -> N map_to_plot = self.saliency_map[target_index] - axes[i].imshow(map_to_plot) + if map_to_plot.ndim == 3: + axes[i].imshow(map_to_plot) + elif map_to_plot.ndim == 2: + axes[i].imshow(map_to_plot, cmap="gray") + else: + raise ValueError(f"Saliency map expected to be 3 or 2-dimensional, but got {map_to_plot.ndim}.") axes[i].axis("off") # Hide the axis axes[i].set_title(f"Class {label_name}") diff --git a/openvino_xai/explainer/visualizer.py b/openvino_xai/explainer/visualizer.py index 32c5b3d4..825dfc50 100644 --- a/openvino_xai/explainer/visualizer.py +++ b/openvino_xai/explainer/visualizer.py @@ -174,14 +174,14 @@ def visualize( # Convert back to dict return self._update_explanation_with_processed_sal_map(explanation, saliency_map_np, indices_to_return) - @staticmethod def _put_classification_info( + self, saliency_map_np: np.ndarray, indices: List[int], label_names: List[str] | None, predictions: Dict[int, Prediction] | None, ) -> None: - corner_location = 3, 17 + offset = 3 for smap, target_index in zip(range(len(saliency_map_np)), indices): label = label_names[target_index] if label_names else str(target_index) if predictions and target_index in predictions: @@ -189,18 +189,19 @@ def _put_classification_info( if score: label = f"{label}|{score:.2f}" + font_scale, text_height = self._fit_text_to_image(label, offset, saliency_map_np[smap].shape[1]) cv2.putText( saliency_map_np[smap], label, - org=corner_location, - fontFace=1, - fontScale=1.3, + org=(offset, text_height + offset), + fontFace=2, + fontScale=font_scale, color=(255, 0, 0), - thickness=2, + thickness=1, ) - @staticmethod def _put_detection_info( + self, saliency_map_np: np.ndarray, indices: List[int], label_names: List[str] | None, @@ -209,6 +210,7 @@ def _put_detection_info( if not predictions: return + offset = 7 for smap, target_index in zip(range(len(saliency_map_np)), indices): saliency_map = saliency_map_np[smap] label_index = predictions[target_index].label @@ -220,17 +222,40 @@ def _put_detection_info( label = label_names[label_index] if label_names else label_index label_score = f"{label}|{score:.2f}" - box_location = int(x1), int(y1 - 5) + + font_scale, _ = self._fit_text_to_image(label_score, x1, saliency_map.shape[1]) + box_location = x1, y1 - offset cv2.putText( saliency_map, label_score, org=box_location, - fontFace=1, - fontScale=1.3, + fontFace=2, + fontScale=font_scale, color=(255, 0, 0), - thickness=2, + thickness=1, ) + @staticmethod + def _fit_text_to_image( + text: str, + x_start: int, + image_width: int, + font_scale: float = 1.0, + thickness: int = 1, + ) -> Tuple[float, int]: + font_face = 2 + max_width = image_width - 5 + while True: + text_size, _ = cv2.getTextSize(text, font_face, font_scale, thickness) + text_width, text_height = text_size + + if x_start + text_width <= max_width: + return font_scale, text_height + + font_scale -= 0.1 + if abs(font_scale - 0.1) < 0.001: + return font_scale, text_height + @staticmethod def _apply_scaling(explanation: Explanation, saliency_map_np: np.ndarray) -> np.ndarray: if explanation.layout not in GRAY_LAYOUTS: diff --git a/openvino_xai/methods/black_box/aise/base.py b/openvino_xai/methods/black_box/aise/base.py index d384077d..2aa5b526 100644 --- a/openvino_xai/methods/black_box/aise/base.py +++ b/openvino_xai/methods/black_box/aise/base.py @@ -10,7 +10,7 @@ import openvino.runtime as ov from scipy.optimize import direct -from openvino_xai.common.utils import IdentityPreprocessFN +from openvino_xai.common.utils import IdentityPreprocessFN, is_bhwc_layout from openvino_xai.methods.black_box.base import BlackBoxXAIMethod @@ -92,6 +92,8 @@ def _objective_function(self, args) -> float: kernel_mask = self._mask_generator.generate_kernel_mask(kernel_params) kernel_mask = np.clip(kernel_mask, 0, 1) + if is_bhwc_layout(self.data_preprocessed): + kernel_mask = np.expand_dims(kernel_mask, 2) pred_loss_preserve = 0.0 if self.preservation: diff --git a/tests/unit/explainer/test_explanation.py b/tests/unit/explainer/test_explanation.py index fed82447..4a49043a 100644 --- a/tests/unit/explainer/test_explanation.py +++ b/tests/unit/explainer/test_explanation.py @@ -12,6 +12,14 @@ from tests.unit.explainer.test_explanation_utils import VOC_NAMES SALIENCY_MAPS = (np.random.rand(1, 20, 5, 5) * 255).astype(np.uint8) +SALIENCY_MAPS_DICT = { + 0: (np.random.rand(5, 5, 3) * 255).astype(np.uint8), + 2: (np.random.rand(5, 5, 3) * 255).astype(np.uint8), +} +SALIENCY_MAPS_DICT_EXCEPTION = { + 0: (np.random.rand(5, 5, 3, 2) * 255).astype(np.uint8), + 2: (np.random.rand(5, 5, 3, 2) * 255).astype(np.uint8), +} SALIENCY_MAPS_IMAGE = (np.random.rand(1, 5, 5) * 255).astype(np.uint8) @@ -106,7 +114,7 @@ def test_plot(self, mocker, caplog): # Update the num columns for the matplotlib visualization grid explanation.plot(backend="matplotlib", num_columns=1) - # Class index that is not in saliency maps will be ommitted with message + # Class index that is not in saliency maps will be omitted with message with caplog.at_level(logging.INFO): explanation.plot([0, 3], backend="matplotlib") assert "Provided class index 3 is not available among saliency maps." in caplog.text @@ -123,3 +131,13 @@ def test_plot(self, mocker, caplog): # Plot activation map explanation = self._get_explanation(saliency_maps=SALIENCY_MAPS_IMAGE, label_names=None) explanation.plot() + + # Plot colored map + explanation = self._get_explanation(saliency_maps=SALIENCY_MAPS_DICT, label_names=None) + explanation.plot() + + # Plot wrong map shape + with pytest.raises(Exception) as exc_info: + explanation = self._get_explanation(saliency_maps=SALIENCY_MAPS_DICT_EXCEPTION, label_names=None) + explanation.plot() + assert str(exc_info.value) == "Saliency map expected to be 3 or 2-dimensional, but got 4." diff --git a/tests/unit/explainer/test_visualization.py b/tests/unit/explainer/test_visualization.py index 9a59fd0a..d177b679 100644 --- a/tests/unit/explainer/test_visualization.py +++ b/tests/unit/explainer/test_visualization.py @@ -10,6 +10,11 @@ from openvino_xai.explainer.visualizer import Visualizer, colormap, overlay, resize from openvino_xai.methods.base import Prediction +ORIGINAL_INPUT_IMAGE = [ + np.ones((100, 100, 3)), + np.ones((10, 10, 3)), +] + SALIENCY_MAPS = [ (np.random.rand(1, 5, 5) * 255).astype(np.uint8), (np.random.rand(1, 2, 5, 5) * 255).astype(np.uint8), @@ -97,6 +102,7 @@ def test_overlay(): class TestVisualizer: + @pytest.mark.parametrize("original_input_image", ORIGINAL_INPUT_IMAGE) @pytest.mark.parametrize("saliency_maps", SALIENCY_MAPS) @pytest.mark.parametrize("explain_all_classes", EXPLAIN_ALL_CLASSES) @pytest.mark.parametrize("task", [Task.CLASSIFICATION, Task.DETECTION]) @@ -107,6 +113,7 @@ class TestVisualizer: @pytest.mark.parametrize("overlay_weight", [0.5, 0.3]) def test_visualizer( self, + original_input_image, saliency_maps, explain_all_classes, task, @@ -124,7 +131,6 @@ def test_visualizer( explanation = Explanation(saliency_maps, targets=explain_targets, task=Task.CLASSIFICATION) raw_sal_map_dims = len(explanation.shape) - original_input_image = np.ones((20, 20, 3)) visualizer = Visualizer() explanation = visualizer( explanation=explanation,