From 00ebd2e1e562fa27bdffe6b222f646d6277a8942 Mon Sep 17 00:00:00 2001 From: Songki Choi Date: Wed, 7 Aug 2024 09:59:33 +0900 Subject: [PATCH 1/3] Upgrade OpenVINO to 2024.3.0 (#52) * Update openvino-dev==2023.3 * Update tests * Update change log --- CHANGELOG.md | 2 ++ openvino_xai/methods/white_box/recipro_cam.py | 4 ++-- pyproject.toml | 2 +- tests/func/test_classification_timm_full.py | 24 +++---------------- 4 files changed, 8 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ab6120a..757bffe9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ * Support OpenVINO IR (.xml) / ONNX (.onnx) model file for `Explainer` model * Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models. +* Upgrade OpenVINO to 2024.3.0 ### What's Changed @@ -15,6 +16,7 @@ * Support OV IR / ONNX model file for Explainer by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/47 * Try CNN -> ViT assumption for IR insertion by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/48 * Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models by @negvet in https://github.com/openvinotoolkit/openvino_xai/pull/49 +* Upgrade OpenVINO to 2024.3.0 by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/52 ### Known Issues diff --git a/openvino_xai/methods/white_box/recipro_cam.py b/openvino_xai/methods/white_box/recipro_cam.py index e552049b..da516125 100644 --- a/openvino_xai/methods/white_box/recipro_cam.py +++ b/openvino_xai/methods/white_box/recipro_cam.py @@ -233,10 +233,10 @@ def _get_saliency_map(self, model_clone: ov.Model) -> ov.Node: norm_node_ori = self._get_non_add_node_from_two_nodes(post_target_node_ori) while norm_node_ori.get_type_name() != "Add": if len(norm_node_ori.outputs()) > 1: - raise ValueError + raise ValueError("Number of normalization outputs > 1!") inputs = norm_node_ori.output(0).get_target_inputs() if len(inputs) > 1: - raise ValueError + raise ValueError("Number of normalization inputs > 1!") norm_node_ori = next(iter(inputs)).get_node() # Mosaic feature map after the LayerNorm diff --git a/pyproject.toml b/pyproject.toml index c852cea1..204e507a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" name = "openvino_xai" version = "1.1.0rc0" dependencies = [ - "openvino-dev==2024.2", + "openvino-dev==2024.3", "opencv-python", "scipy", "numpy==1.*", diff --git a/tests/func/test_classification_timm_full.py b/tests/func/test_classification_timm_full.py index e7bc6a5f..73845a64 100644 --- a/tests/func/test_classification_timm_full.py +++ b/tests/func/test_classification_timm_full.py @@ -32,36 +32,19 @@ SUPPORTED_BUT_FAILED_BY_BB_MODELS = {} NOT_SUPPORTED_BY_BB_MODELS = { - "_nfnet_": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.", - "convit": "RuntimeError: Couldn't get TorchScript module by tracing.", - "convnext_xxlarge": "RuntimeError: The serialized model is larger than the 2GiB limit imposed by the protobuf library.", - "convnextv2_huge": "RuntimeError: The serialized model is larger than the 2GiB limit imposed by the protobuf library.", - "deit3_huge": "RuntimeError: The serialized model is larger than the 2GiB limit imposed by the protobuf library.", - "dm_nfnet": "openvino._pyopenvino.GeneralFailure: Check 'false' failed at src/frontends/onnx/frontend/src/frontend.cpp:144", - "eca_nfnet": "openvino._pyopenvino.GeneralFailure: Check 'false' failed at src/frontends/onnx/frontend/src/frontend.cpp:144", - "eva_giant": "RuntimeError: The serialized model is larger than the 2GiB limit imposed by the protobuf library.", - # "halo": "torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible.", - "nf_regnet": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.", - "nf_resnet": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.", - "nfnet_l0": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.", - "regnety_1280": "RuntimeError: The serialized model is larger than the 2GiB limit imposed by the protobuf library.", - "regnety_2560": "RuntimeError: The serialized model is larger than the 2GiB limit imposed by the protobuf library.", "repvit": "urllib.error.HTTPError: HTTP Error 404: Not Found", - "resnetv2": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.", "tf_efficientnet_cc": "torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of convolution for kernel of unknown shape.", "vit_base_r50_s16_224.orig_in21k": "RuntimeError: Error(s) in loading state_dict for VisionTransformer", "vit_gigantic_patch16_224_ijepa.in22k": "RuntimeError: shape '[1, 13, 13, -1]' is invalid for input of size 274560", "vit_huge_patch14_224.orig_in21k": "RuntimeError: Error(s) in loading state_dict for VisionTransformer", "vit_large_patch32_224.orig_in21k": "RuntimeError: Error(s) in loading state_dict for VisionTransformer", - "vit_large_r50_s32": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.", - "vit_small_r26_s32": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.", - "vit_tiny_r_s16": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.", - "volo_": "torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::col2im' to ONNX opset version 14 is not supported.", + "volo_": "RuntimeError: Exception from src/core/src/dimension.cpp:227: Cannot get length of dynamic dimension", } SUPPORTED_BUT_FAILED_BY_WB_MODELS = { - "convformer": "Cannot find output backbone_node in auto mode, please provide target_layer.", "swin": "Only two outputs of the between block Add node supported, but got 1. Try to use black-box.", + "vit_base_patch16_rpn_224": "Number of normalization outputs > 1", + "vit_relpos_medium_patch16_rpn_224": "ValueError in openvino_xai/methods/white_box/recipro_cam.py:215", } NOT_SUPPORTED_BY_WB_MODELS = { @@ -103,7 +86,6 @@ "tnt_s_patch16_224": "Only two outputs of the between block Add node supported, but got 1.", "twins": "One (and only one) of the nodes has to be Add type. But got ShapeOf and Transpose.", # work in CNN mode -> "visformer": "Cannot find output backbone_node in auto mode, please provide target_layer", - "vit_relpos_medium_patch16_rpn_224": "ValueError in openvino_xai/methods/white_box/recipro_cam.py:215", } From 6f42d8df1e49a95e5867003ad7bb6acf1bcb3239 Mon Sep 17 00:00:00 2001 From: Galina Zalesskaya Date: Fri, 9 Aug 2024 12:55:15 +0300 Subject: [PATCH 2/3] Add explanation.plot() (#53) * Add explanation.plot * Add documentation * Update matplotlib * Fixes from comments * Activation map support * Attempt to fix tox unit tests * Comments from negvet * Add warning + grid output * Fix comments --- CHANGELOG.md | 2 + docs/source/user-guide.md | 71 ++++++++++++++++- openvino_xai/explainer/explanation.py | 90 ++++++++++++++++++++++ pyproject.toml | 1 + tests/unit/explanation/test_explanation.py | 38 +++++++++ 5 files changed, 198 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 757bffe9..2907f2d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ * Support OpenVINO IR (.xml) / ONNX (.onnx) model file for `Explainer` model * Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models. * Upgrade OpenVINO to 2024.3.0 +* Add saliency map visualization with explanation.plot() ### What's Changed @@ -17,6 +18,7 @@ * Try CNN -> ViT assumption for IR insertion by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/48 * Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models by @negvet in https://github.com/openvinotoolkit/openvino_xai/pull/49 * Upgrade OpenVINO to 2024.3.0 by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/52 +* Add saliency map visualization with explanation.plot() by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/53 ### Known Issues diff --git a/docs/source/user-guide.md b/docs/source/user-guide.md index 93ee4cf8..23db8cf4 100644 --- a/docs/source/user-guide.md +++ b/docs/source/user-guide.md @@ -13,12 +13,16 @@ Content: - [OpenVINO™ Explainable AI Toolkit User Guide](#openvino-explainable-ai-toolkit-user-guide) - [OpenVINO XAI Architecture](#openvino-xai-architecture) - [`Explainer`: the main interface to XAI algorithms](#explainer-the-main-interface-to-xai-algorithms) + - [Create Explainer for OpenVINO Model instance](#create-explainer-for-openvino-model-instance) + - [Create Explainer from OpenVINO IR file](#create-explainer-from-openvino-ir-file) + - [Create Explainer from ONNX model file](#create-explainer-from-onnx-model-file) - [Basic usage: Auto mode](#basic-usage-auto-mode) - [Running without `preprocess_fn`](#running-without-preprocess_fn) - [Specifying `preprocess_fn`](#specifying-preprocess_fn) - [White-Box mode](#white-box-mode) - [Black-Box mode](#black-box-mode) - [XAI insertion (white-box usage)](#xai-insertion-white-box-usage) + - [Plot saliency maps](#plot-saliency-maps) - [Example scripts](#example-scripts) @@ -97,12 +101,12 @@ Here's the example how we can avoid passing `preprocess_fn` by preprocessing dat import cv2 import numpy as np import openvino.runtime as ov -from openvino.runtime.utils.data_helpers.wrappers import OVDict +from from typing import Mapping import openvino_xai as xai -def postprocess_fn(x: OVDict): +def postprocess_fn(x: Mapping): # Implementing our own post-process function based on the model's implementation # Return "logits" model output return x["logits"] @@ -143,7 +147,7 @@ explanation.save("output_path", "name") import cv2 import numpy as np import openvino.runtime as ov -from openvino.runtime.utils.data_helpers.wrappers import OVDict +from typing import Mapping import openvino_xai as xai @@ -154,7 +158,7 @@ def preprocess_fn(x: np.ndarray) -> np.ndarray: x = np.expand_dims(x, 0) return x -def postprocess_fn(x: OVDict): +def postprocess_fn(x: Mapping): # Implementing our own post-process function based on the model's implementation # Return "logits" model output return x["logits"] @@ -327,6 +331,65 @@ model_xai = xai.insert_xai( # ***** Downstream task: user's code that infers model_xai and picks 'saliency_map' output ***** ``` +## Plot saliency maps + +To visualize saliency maps, use the `explanation.plot` function. + +The `matplotlib` backend is more convenient for plotting saliency maps in Jupyter notebooks, as it uses the Matplotlib library. By default it generates the grid with 4 images per row (can be agjusted by `num_collumns` parameter). + +The `cv` backend is better for visualization in Python scripts, as it opens extra windows to display the generated saliency maps. + +```python +import cv2 +import numpy as np +import openvino.runtime as ov +import openvino_xai as xai + +def preprocess_fn(image: np.ndarray) -> np.ndarray: + """Preprocess the input image.""" + resized_image = cv2.resize(src=image, dsize=(224, 224)) + expanded_image = np.expand_dims(resized_image, 0) + return expanded_image + +# Create ov.Model +MODEL_PATH = "path/to/model.xml" +model = ov.Core().read_model(MODEL_PATH) # type: ov.Model + +# The Explainer object will prepare and load the model once in the beginning +explainer = xai.Explainer( + model, + task=xai.Task.CLASSIFICATION, + preprocess_fn=preprocess_fn, +) + +voc_labels = [ + 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', + 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' +] + +# Generate and process saliency maps (as many as required, sequentially) +image = cv2.imread("path/to/image.jpg") + +# Run explanation +explanation = explainer( + image, + explain_mode=ExplainMode.WHITEBOX, + label_names=voc_labels, + target_explain_labels=[7, 11], # ['cat', 'dog'] also possible as target classes to explain +) + +# Use matplotlib (recommended for Jupyter) - default backend +explanation.plot() # plot all saliency map +explanation.plot(targets=[7], backend="matplotlib") +explanation.plot(targets=["cat"], backend="matplotlib") +# Plots a grid with 5 images per row +explanation.plot(num_columns=5, backend="matplotlib") + +# Use OpenCV (recommended for Python) - will open new windows with saliency maps +explanation.plot(backend="cv") # plot all saliency map +explanation.plot(targets=[7], backend="cv") +explanation.plot(targets=["cat"], backend="cv") +``` ## Example scripts diff --git a/openvino_xai/explainer/explanation.py b/openvino_xai/explainer/explanation.py index bd63b676..cba77af1 100644 --- a/openvino_xai/explainer/explanation.py +++ b/openvino_xai/explainer/explanation.py @@ -7,8 +7,10 @@ from typing import Dict, List import cv2 +import matplotlib.pyplot as plt import numpy as np +from openvino_xai.common.utils import logger from openvino_xai.explainer.utils import ( convert_targets_to_numpy, explains_all, @@ -149,6 +151,94 @@ def save(self, dir_path: Path | str, name: str | None = None) -> None: image_name = f"{save_name}_target_{target_name}.jpg" if save_name else f"target_{target_name}.jpg" cv2.imwrite(os.path.join(dir_path, image_name), img=map_to_save) + def plot( + self, + targets: np.ndarray | List[int | str] | None = None, + backend: str = "matplotlib", + max_num_plots: int = 24, + num_columns: int = 4, + ) -> None: + """ + Plots saliency maps using the specified backend. + + This function plots available saliency maps using the specified backend. Targets to plot + can be specified by passing a list of target class indices or names. If a provided class is + not available among the saliency maps, it is omitted. + + Args: + targets (np.ndarray | List[int | str] | None): A list or array of target class indices or names to plot. + By default, it's None, and all available saliency maps are plotted. + backend (str): The plotting backend to use. Can be either 'matplotlib' (recommended for Jupyter) + or 'cv' (recommended for Python scripts). Default is 'matplotlib'. + max_num_plots (int): Max number of images to plot. Default is 24 to avoid memory issues. + num_columns (int): Number of columns in the saliency maps visualization grid for the matplotlib backend. + """ + + if targets is None or explains_all(targets): + checked_targets = self.targets + else: + target_indices = get_target_indices(targets, self.label_names) + checked_targets = [] + for target_index in target_indices: + if target_index in self.saliency_map: + checked_targets.append(target_index) + else: + logger.info(f"Provided class index {target_index} is not available among saliency maps.") + + if len(checked_targets) > max_num_plots: + logger.warning( + f"Decrease the number of plotted saliency maps from {len(checked_targets)} to {max_num_plots}" + " to avoid the memory issue. To avoid this, increase the 'max_num_plots' argument." + ) + checked_targets = checked_targets[:max_num_plots] + + if backend == "matplotlib": + self._plot_matplotlib(checked_targets, num_columns) + elif backend == "cv": + self._plot_cv(checked_targets) + else: + raise ValueError(f"Unknown backend {backend}. Use 'matplotlib' or 'cv'.") + + def _plot_matplotlib(self, checked_targets: list[int | str], num_cols: int) -> None: + """Plots saliency maps using matplotlib.""" + num_rows = int(np.ceil(len(checked_targets) / num_cols)) + _, axes = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 6 * num_rows)) + axes = axes.flatten() + + for i, target_index in enumerate(checked_targets): + if self.label_names and isinstance(target_index, np.int64): + label_name = f"{self.label_names[target_index]} ({target_index})" + else: + label_name = str(target_index) + + map_to_plot = self.saliency_map[target_index] + + axes[i].imshow(map_to_plot) + axes[i].axis("off") # Hide the axis + axes[i].set_title(f"Class {label_name}") + + # Hide remaining axes + for ax in axes[len(checked_targets) :]: + ax.set_visible(False) + + plt.tight_layout() + plt.show() + + def _plot_cv(self, checked_targets: list[int | str]) -> None: + """Plots saliency maps using OpenCV.""" + for target_index in checked_targets: + if self.label_names and isinstance(target_index, np.int64): + label_name = f"{self.label_names[target_index]} ({target_index})" + else: + label_name = str(target_index) + + map_to_plot = self.saliency_map[target_index] + map_to_plot = cv2.cvtColor(map_to_plot, cv2.COLOR_BGR2RGB) + + cv2.imshow(f"Class {label_name}", map_to_plot) + cv2.waitKey(0) + cv2.destroyAllWindows() + class Layout(Enum): """ diff --git a/pyproject.toml b/pyproject.toml index 204e507a..7dbfc561 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "scipy", "numpy==1.*", "tqdm", + "matplotlib", ] requires-python = ">=3.10" authors = [ diff --git a/tests/unit/explanation/test_explanation.py b/tests/unit/explanation/test_explanation.py index 6975b5e2..d4d519ff 100644 --- a/tests/unit/explanation/test_explanation.py +++ b/tests/unit/explanation/test_explanation.py @@ -1,9 +1,11 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import logging import os import numpy as np +import pytest from openvino_xai.explainer.explanation import Explanation from tests.unit.explanation.test_explanation_utils import VOC_NAMES @@ -67,3 +69,39 @@ def _get_explanation(self, saliency_maps=SALIENCY_MAPS, label_names=VOC_NAMES): label_names=label_names, ) return explanation + + def test_plot(self, mocker, caplog): + explanation = self._get_explanation() + + # Invalid backend + with pytest.raises(ValueError): + explanation.plot(backend="invalid") + + # Plot all saliency maps + explanation.plot() + # Matplotloib backend + explanation.plot([0, 2], backend="matplotlib") + # Targets as label names + explanation.plot(["aeroplane", "bird"], backend="matplotlib") + # Plot all saliency maps + explanation.plot(-1, backend="matplotlib") + # Update the num columns for the matplotlib visualization grid + explanation.plot(backend="matplotlib", num_columns=1) + + # Class index that is not in saliency maps will be ommitted with message + with caplog.at_level(logging.INFO): + explanation.plot([0, 3], backend="matplotlib") + assert "Provided class index 3 is not available among saliency maps." in caplog.text + + # Check threshold + with caplog.at_level(logging.WARNING): + explanation.plot([0, 2], backend="matplotlib", max_num_plots=1) + + # CV backend + mocker.patch("cv2.imshow") + mocker.patch("cv2.waitKey") + explanation.plot([0, 2], backend="cv") + + # Plot activation map + explanation = self._get_explanation(saliency_maps=SALIENCY_MAPS_IMAGE, label_names=None) + explanation.plot() From ccfadb9adb500acbd9c360be5e41b7ca91bc9880 Mon Sep 17 00:00:00 2001 From: Galina Zalesskaya Date: Fri, 9 Aug 2024 17:12:49 +0300 Subject: [PATCH 3/3] Make flexible naming in explanation.save (#51) * Update explanation.save * Add documentation * Fix comments * Rename components * 3 -> 2 arguments --- CHANGELOG.md | 4 +- docs/source/user-guide.md | 106 +++++++++++++++++++-- openvino_xai/explainer/explanation.py | 61 +++++++++--- openvino_xai/methods/black_box/aise.py | 7 +- openvino_xai/methods/black_box/rise.py | 2 +- tests/unit/explanation/test_explanation.py | 34 +++++-- 6 files changed, 180 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2907f2d2..f6c33f2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,9 +5,10 @@ ### Summary * Support OpenVINO IR (.xml) / ONNX (.onnx) model file for `Explainer` model -* Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models. +* Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models * Upgrade OpenVINO to 2024.3.0 * Add saliency map visualization with explanation.plot() +* Enable flexible naming for saved saliency maps and include confidence scores ### What's Changed @@ -19,6 +20,7 @@ * Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models by @negvet in https://github.com/openvinotoolkit/openvino_xai/pull/49 * Upgrade OpenVINO to 2024.3.0 by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/52 * Add saliency map visualization with explanation.plot() by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/53 +* Enable flexible naming for saved saliency maps and include confidence scores by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/51 ### Known Issues diff --git a/docs/source/user-guide.md b/docs/source/user-guide.md index 23db8cf4..8916e046 100644 --- a/docs/source/user-guide.md +++ b/docs/source/user-guide.md @@ -23,6 +23,7 @@ Content: - [Black-Box mode](#black-box-mode) - [XAI insertion (white-box usage)](#xai-insertion-white-box-usage) - [Plot saliency maps](#plot-saliency-maps) + - [Saving saliency maps](#saving-saliency-maps) - [Example scripts](#example-scripts) @@ -100,8 +101,8 @@ Here's the example how we can avoid passing `preprocess_fn` by preprocessing dat ```python import cv2 import numpy as np +from typing import Mapping import openvino.runtime as ov -from from typing import Mapping import openvino_xai as xai @@ -137,7 +138,7 @@ explanation = explainer( ) # Save saliency maps -explanation.save("output_path", "name") +explanation.save("output_path", "name_") ``` ### Specifying `preprocess_fn` @@ -146,8 +147,8 @@ explanation.save("output_path", "name") ```python import cv2 import numpy as np -import openvino.runtime as ov from typing import Mapping +import openvino.runtime as ov import openvino_xai as xai @@ -184,7 +185,7 @@ explanation = explainer( ) # Save saliency maps -explanation.save("output_path", "name") +explanation.save("output_path", "name_") ``` @@ -242,7 +243,7 @@ explanation = explainer( ) # Save saliency maps -explanation.save("output_path", "name") +explanation.save("output_path", "name_") ``` @@ -298,7 +299,7 @@ explanation = explainer( ) # Save saliency maps -explanation.save("output_path", "name") +explanation.save("output_path", "name_") ``` @@ -343,7 +344,9 @@ The `cv` backend is better for visualization in Python scripts, as it opens extr import cv2 import numpy as np import openvino.runtime as ov + import openvino_xai as xai +from openvino_xai.explainer import ExplainMode def preprocess_fn(image: np.ndarray) -> np.ndarray: """Preprocess the input image.""" @@ -391,6 +394,97 @@ explanation.plot(targets=[7], backend="cv") explanation.plot(targets=["cat"], backend="cv") ``` +## Saving saliency maps + +You can easily save saliency maps with flexible naming options by using a `prefix` and `postfix`. The `prefix` allows saliency maps from the same image to have consistent naming. + +The format for naming is: + +`{prefix} + target_id + {postfix}.jpg` + +Additionally, you can include the confidence score for each class in the saved saliency map's name. + +`{prefix} + target_id + {postfix} + confidence.jpg` + +```python +import cv2 +import numpy as np +import openvino.runtime as ov +from typing import Mapping + +import openvino_xai as xai +from openvino_xai.explainer import ExplainMode + +def preprocess_fn(image: np.ndarray) -> np.ndarray: + """Preprocess the input image.""" + x = cv2.resize(src=image, dsize=(224, 224)) + x = x.transpose((2, 0, 1)) + processed_image = np.expand_dims(x, 0) + return processed_image + +def postprocess_fn(output: Mapping): + """Postprocess the model output.""" + output = softmax(output) + return output[0] + +def softmax(x: np.ndarray) -> np.ndarray: + """Compute softmax values of x.""" + e_x = np.exp(x - np.max(x)) + return e_x / e_x.sum() + +# Generate and process saliency maps (as many as required, sequentially) +image = cv2.imread("path/to/image.jpg") + +# Create ov.Model +MODEL_PATH = "path/to/model.xml" +model = ov.Core().read_model(MODEL_PATH) # type: ov.Model + +# The Explainer object will prepare and load the model once in the beginning +explainer = xai.Explainer( + model, + task=xai.Task.CLASSIFICATION, + preprocess_fn=preprocess_fn, +) + +voc_labels = [ + 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', + 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' +] + +# Get predicted confidences for the image +compiled_model = core.compile_model(model=model, device_name="AUTO") +logits = compiled_model(preprocess_fn(image))[0] +result_infer = postprocess_fn(logits) + +# Generate list of predicted class indices and scores +result_idxs = np.argwhere(result_infer > 0.4).flatten() +result_scores = result_infer[result_idxs] + +# Generate dict {class_index: confidence} to save saliency maps +scores_dict = {i: score for i, score in zip(result_idxs, result_scores)} + +# Run explanation +explanation = explainer( + image, + explain_mode=ExplainMode.WHITEBOX, + label_names=voc_labels, + target_explain_labels=result_idxs, # target classes to explain +) + +# Save saliency maps flexibly +OUTPUT_PATH = "output_path" +explanation.save(OUTPUT_PATH) # aeroplane.jpg +explanation.save(OUTPUT_PATH, "image_name_target_") # image_name_target_aeroplane.jpg +explanation.save(OUTPUT_PATH, prefix="image_name_target_") # image_name_target_aeroplane.jpg +explanation.save(OUTPUT_PATH, postfix="_class_map") # aeroplane_class_map.jpg +explanation.save(OUTPUT_PATH, prefix="image_name_", postfix="_class_map") # image_name_aeroplane_class_map.jpg + +# Save saliency maps with confidence scores +explanation.save( + OUTPUT_PATH, prefix="image_name_", postfix="_conf_", confidence_scores=scores_dict +) # image_name_aeroplane_conf_0.85.jpg +``` + ## Example scripts More usage scenarios that can be used with your own models and images as arguments are available in [examples](../../examples). diff --git a/openvino_xai/explainer/explanation.py b/openvino_xai/explainer/explanation.py index cba77af1..e4b5f929 100644 --- a/openvino_xai/explainer/explanation.py +++ b/openvino_xai/explainer/explanation.py @@ -134,21 +134,58 @@ def _select_target_indices( raise ValueError("Provided targer index {targer_index} is not available among saliency maps.") return target_indices - def save(self, dir_path: Path | str, name: str | None = None) -> None: - """Dumps saliency map.""" + def save( + self, + dir_path: Path | str, + prefix: str = "", + postfix: str = "", + confidence_scores: Dict[int, float] | None = None, + ) -> None: + """ + Dumps saliency map images to the specified directory. + + Allows flexibly name the files with the prefix and postfix. + {prefix} + target_id + {postfix}.jpg + + Also allows to add confidence scores to the file names. + {prefix} + target_id + {postfix} + confidence.jpg + + save(output_dir) -> aeroplane.jpg + save(output_dir, prefix="image_name_target_") -> image_name_target_aeroplane.jpg + save(output_dir, postfix="_class_map") -> aeroplane_class_map.jpg + save( + output_dir, prefix="image_name_", postfix="_conf_", confidence_scores=scores + ) -> image_name_aeroplane_conf_0.85.jpg + + Parameters: + :param dir_path: The directory path where the saliency maps will be saved. + :type dir_path: Path | str + :param prefix: Optional prefix for the saliency map names. Default is an empty string. + :type prefix: str + :param postfix: Optional postfix for the saliency map names. Default is an empty string. + :type postfix: str + :param confidence_scores: Dict with confidence scores for each class index. Default is None. + :type confidence_scores: Dict[int, float] | None + + """ + os.makedirs(dir_path, exist_ok=True) - save_name = name if name else "" - for cls_idx, map_to_save in self._saliency_map.items(): + + template = f"{prefix}{{target_name}}{postfix}{{conf_score}}.jpg" + for target_idx, map_to_save in self._saliency_map.items(): + conf_score = "" map_to_save = cv2.cvtColor(map_to_save, code=cv2.COLOR_RGB2BGR) - if isinstance(cls_idx, str): - cv2.imwrite(os.path.join(dir_path, f"{save_name}.jpg"), img=map_to_save) - return + if isinstance(target_idx, str): + target_name = "activation_map" + elif self.label_names and isinstance(target_idx, np.int64): + target_name = self.label_names[target_idx] else: - if self.label_names: - target_name = self.label_names[cls_idx] - else: - target_name = str(cls_idx) - image_name = f"{save_name}_target_{target_name}.jpg" if save_name else f"target_{target_name}.jpg" + target_name = str(target_idx) + + if confidence_scores and target_idx in confidence_scores: + conf_score = f"{confidence_scores[int(target_idx)]:.2f}" + + image_name = template.format(target_name=target_name, conf_score=conf_score) cv2.imwrite(os.path.join(dir_path, image_name), img=map_to_save) def plot( diff --git a/openvino_xai/methods/black_box/aise.py b/openvino_xai/methods/black_box/aise.py index 722e9547..b2e27324 100644 --- a/openvino_xai/methods/black_box/aise.py +++ b/openvino_xai/methods/black_box/aise.py @@ -3,11 +3,10 @@ import collections import math -from typing import Callable, Dict, List, Tuple +from typing import Callable, Dict, List, Mapping, Tuple import numpy as np import openvino.runtime as ov -from openvino.runtime.utils.data_helpers.wrappers import OVDict from scipy.optimize import Bounds, direct from openvino_xai.common.utils import ( @@ -28,7 +27,7 @@ class AISE(BlackBoxXAIMethod): :param model: OpenVINO model. :type model: ov.Model :param postprocess_fn: Post-processing function that extract scores from IR model output. - :type postprocess_fn: Callable[[OVDict], np.ndarray] + :type postprocess_fn: Callable[[Mapping], np.ndarray] :param preprocess_fn: Pre-processing function, identity function by default (assume input images are already preprocessed by user). :type preprocess_fn: Callable[[np.ndarray], np.ndarray] @@ -41,7 +40,7 @@ class AISE(BlackBoxXAIMethod): def __init__( self, model: ov.Model, - postprocess_fn: Callable[[OVDict], np.ndarray], + postprocess_fn: Callable[[Mapping], np.ndarray], preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(), device_name: str = "CPU", prepare_model: bool = True, diff --git a/openvino_xai/methods/black_box/rise.py b/openvino_xai/methods/black_box/rise.py index afb3beaa..143b03b7 100644 --- a/openvino_xai/methods/black_box/rise.py +++ b/openvino_xai/methods/black_box/rise.py @@ -20,7 +20,7 @@ class RISE(BlackBoxXAIMethod): :param model: OpenVINO model. :type model: ov.Model :param postprocess_fn: Post-processing function that extract scores from IR model output. - :type postprocess_fn: Callable[[OVDict], np.ndarray] + :type postprocess_fn: Callable[[Mapping], np.ndarray] :param preprocess_fn: Pre-processing function, identity function by default (assume input images are already preprocessed by user). :type preprocess_fn: Callable[[np.ndarray], np.ndarray] diff --git a/tests/unit/explanation/test_explanation.py b/tests/unit/explanation/test_explanation.py index d4d519ff..1cf2a791 100644 --- a/tests/unit/explanation/test_explanation.py +++ b/tests/unit/explanation/test_explanation.py @@ -43,23 +43,37 @@ def test_save(self, tmp_path): save_path = tmp_path / "saliency_maps" explanation = self._get_explanation() - explanation.save(save_path, "test_map") - assert os.path.isfile(save_path / "test_map_target_aeroplane.jpg") - assert os.path.isfile(save_path / "test_map_target_bird.jpg") + explanation.save(save_path, prefix="image_name_") + assert os.path.isfile(save_path / "image_name_aeroplane.jpg") + assert os.path.isfile(save_path / "image_name_bird.jpg") explanation = self._get_explanation() explanation.save(save_path) - assert os.path.isfile(save_path / "target_aeroplane.jpg") - assert os.path.isfile(save_path / "target_bird.jpg") + assert os.path.isfile(save_path / "aeroplane.jpg") + assert os.path.isfile(save_path / "bird.jpg") explanation = self._get_explanation(label_names=None) - explanation.save(save_path, "test_map") - assert os.path.isfile(save_path / "test_map_target_0.jpg") - assert os.path.isfile(save_path / "test_map_target_2.jpg") + explanation.save(save_path, postfix="_class_map") + assert os.path.isfile(save_path / "0_class_map.jpg") + assert os.path.isfile(save_path / "2_class_map.jpg") + + explanation = self._get_explanation() + explanation.save(save_path, prefix="image_name_", postfix="_map") + assert os.path.isfile(save_path / "image_name_aeroplane_map.jpg") + assert os.path.isfile(save_path / "image_name_bird_map.jpg") + + explanation = self._get_explanation() + explanation.save(save_path, postfix="_conf_", confidence_scores={0: 0.92, 2: 0.85}) + assert os.path.isfile(save_path / "aeroplane_conf_0.92.jpg") + assert os.path.isfile(save_path / "bird_conf_0.85.jpg") + + explanation = self._get_explanation(saliency_maps=SALIENCY_MAPS_IMAGE, label_names=None) + explanation.save(save_path, prefix="test_map_") + assert os.path.isfile(save_path / "test_map_activation_map.jpg") explanation = self._get_explanation(saliency_maps=SALIENCY_MAPS_IMAGE, label_names=None) - explanation.save(save_path, "test_map") - assert os.path.isfile(save_path / "test_map.jpg") + explanation.save(save_path, prefix="test_map_", postfix="_result") + assert os.path.isfile(save_path / "test_map_activation_map_result.jpg") def _get_explanation(self, saliency_maps=SALIENCY_MAPS, label_names=VOC_NAMES): explain_targets = [0, 2]