diff --git a/openvino_xai/explainer/explainer.py b/openvino_xai/explainer/explainer.py
index 31c08afd..a27fc74b 100644
--- a/openvino_xai/explainer/explainer.py
+++ b/openvino_xai/explainer/explainer.py
@@ -221,8 +221,9 @@ def explain(
         explanation = Explanation(
             saliency_map=saliency_map,
             targets=targets,
+            task=self.task,
             label_names=label_names,
-            metadata=self.method.metadata,
+            predictions=self.method.predictions,
         )
         return self._visualize(
             original_input_image,
diff --git a/openvino_xai/explainer/explanation.py b/openvino_xai/explainer/explanation.py
index 13f61ef5..d071c369 100644
--- a/openvino_xai/explainer/explanation.py
+++ b/openvino_xai/explainer/explanation.py
@@ -17,6 +17,7 @@
     explains_all,
     get_target_indices,
 )
+from openvino_xai.methods.base import Prediction
 
 
 class Explanation:
@@ -36,8 +37,9 @@ def __init__(
         self,
         saliency_map: np.ndarray | Dict[int | str, np.ndarray],
         targets: np.ndarray | List[int | str] | int | str,
+        task: Task,
         label_names: List[str] | None = None,
-        metadata: Dict[Task, Any] | None = None,
+        predictions: Dict[int, Prediction] | None = None,
     ):
         targets = convert_targets_to_numpy(targets)
 
@@ -57,10 +59,14 @@ def __init__(
             self.layout = Layout.MULTIPLE_MAPS_PER_IMAGE_GRAY
 
         if not explains_all(targets) and not self.layout == Layout.ONE_MAP_PER_IMAGE_GRAY:
-            self._saliency_map = self._select_target_saliency_maps(targets, label_names)
+            if task == Task.DETECTION:
+                self._saliency_map = self._select_target_saliency_maps(targets, None)
+            else:
+                self._saliency_map = self._select_target_saliency_maps(targets, label_names)
 
+        self.task = task
         self.label_names = label_names
-        self.metadata = metadata
+        self.predictions = predictions
 
     @property
     def saliency_map(self) -> Dict[int | str, np.ndarray]:
diff --git a/openvino_xai/explainer/visualizer.py b/openvino_xai/explainer/visualizer.py
index aedd8295..ade91ae9 100644
--- a/openvino_xai/explainer/visualizer.py
+++ b/openvino_xai/explainer/visualizer.py
@@ -16,6 +16,7 @@
     Explanation,
     Layout,
 )
+from openvino_xai.methods.base import Prediction
 
 
 def resize(saliency_map: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
@@ -147,7 +148,10 @@ def visualize(
             saliency_map_np = self._apply_overlay(
                 explanation, saliency_map_np, original_input_image, output_size, overlay_weight
             )
-            saliency_map_np = self._apply_metadata(explanation.metadata, saliency_map_np, indices_to_return)
+            if explanation.task == Task.CLASSIFICATION:
+                self._put_classification_info(saliency_map_np, indices_to_return, explanation.label_names, explanation.predictions)
+            if explanation.task == Task.DETECTION and explanation.predictions:
+                self._put_detection_info(saliency_map_np, indices_to_return, explanation.label_names, explanation.predictions)
         else:
             if resize:
                 if original_input_image is None and output_size is None:
@@ -162,27 +166,51 @@ def visualize(
         return self._update_explanation_with_processed_sal_map(explanation, saliency_map_np, indices_to_return)
 
     @staticmethod
-    def _apply_metadata(metadata: Dict[Task, Any], saliency_map_np: np.ndarray, indices: List[int | str]):
-        # TODO (negvet): support when indices are strings
-        if metadata:
-            if Task.DETECTION in metadata:
-                for smap_i, target_index in zip(range(len(saliency_map_np)), indices):
-                    saliency_map = saliency_map_np[smap_i]
-                    box, score, label_index = metadata[Task.DETECTION][target_index]
-                    x1, y1, x2, y2 = box
-                    cv2.rectangle(saliency_map, (int(x1), int(y1)), (int(x2), int(y2)), color=(255, 0, 0), thickness=2)
-                    box_label = f"{label_index}|{score:.2f}"
-                    box_label_loc = int(x1), int(y1 - 5)
-                    cv2.putText(
-                        saliency_map,
-                        box_label,
-                        org=box_label_loc,
-                        fontFace=1,
-                        fontScale=1,
-                        color=(255, 0, 0),
-                        thickness=2,
-                    )
-        return saliency_map_np
+    def _put_classification_info(saliency_map_np: np.ndarray, indices: List[int | str], label_names: List[str] | None, predictions: Dict[int, Prediction]) -> None:
+        """Write each target's label, plus "|score" when a score is known, onto its saliency map in place."""
+        for smap, target_index in zip(range(len(saliency_map_np)), indices):
+            corner_location = 3, 17
+            label = label_names[target_index] if label_names else str(target_index)
+            if predictions and target_index in predictions:
+                score = predictions[target_index].score
+                # Compare against None explicitly: a score of exactly 0.0 is falsy but still valid.
+                if score is not None:
+                    label = f"{label}|{score:.2f}"
+
+            cv2.putText(
+                saliency_map_np[smap],
+                label,
+                org=corner_location,
+                fontFace=1,
+                fontScale=1.3,
+                color=(255, 0, 0),
+                thickness=2,
+            )
+
+    @staticmethod
+    def _put_detection_info(saliency_map_np: np.ndarray, indices: List[int | str], label_names: List[str] | None, predictions: Dict[int, Prediction]) -> None:
+        """Draw each target's bounding box and its "label|score" caption onto its saliency map in place."""
+        for smap, target_index in zip(range(len(saliency_map_np)), indices):
+            saliency_map = saliency_map_np[smap]
+            label_index = predictions[target_index].label
+            score = predictions[target_index].score
+            box = predictions[target_index].bounding_box
+
+            x1, y1, x2, y2 = box
+            cv2.rectangle(saliency_map, (int(x1), int(y1)), (int(x2), int(y2)), color=(255, 0, 0), thickness=2)
+
+            label = label_names[label_index] if label_names else label_index
+            label_score = f"{label}|{score:.2f}"
+            box_location = int(x1), int(y1 - 5)
+            cv2.putText(
+                saliency_map,
+                label_score,
+                org=box_location,
+                fontFace=1,
+                fontScale=1.3,
+                color=(255, 0, 0),
+                thickness=2,
+            )
 
     @staticmethod
     def _apply_scaling(explanation: Explanation, saliency_map_np: np.ndarray) -> np.ndarray:
diff --git a/openvino_xai/methods/base.py b/openvino_xai/methods/base.py
index 59d207db..b4e236f3 100644
--- a/openvino_xai/methods/base.py
+++ b/openvino_xai/methods/base.py
@@ -3,7 +3,8 @@
 
 import collections
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Mapping
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Mapping, Tuple
 
 import numpy as np
 import openvino as ov
@@ -25,7 +26,7 @@ def __init__(
         self._model_compiled = None
         self.preprocess_fn = preprocess_fn
         self._device_name = device_name
-        self.metadata: Dict[Task, Any] = collections.defaultdict(dict)
+        self.predictions: Dict[int, "Prediction"] = {}
 
     @property
     def model_compiled(self) -> ov.CompiledModel | None:
@@ -50,3 +51,11 @@ def generate_saliency_map(self, data: np.ndarray) -> Dict[int, np.ndarray] | np.
     def load_model(self) -> None:
         core = ov.Core()
         self._model_compiled = core.compile_model(model=self._model, device_name=self._device_name)
+
+
+@dataclass
+class Prediction:
+    """Label, score and optional bounding box recorded for one explained target."""
+    label: int | None = None
+    score: float | None = None
+    bounding_box: Tuple | None = None
diff --git a/openvino_xai/methods/black_box/aise/classification.py b/openvino_xai/methods/black_box/aise/classification.py
index a4f3e340..f9b38f2e 100644
--- a/openvino_xai/methods/black_box/aise/classification.py
+++ b/openvino_xai/methods/black_box/aise/classification.py
@@ -16,6 +16,7 @@
     scaling,
     sigmoid,
 )
+from openvino_xai.methods.base import Prediction
 from openvino_xai.methods.black_box.aise.base import AISEBase, GaussianPerturbationMask
 from openvino_xai.methods.black_box.base import Preset
 from openvino_xai.methods.black_box.utils import check_classification_output
@@ -91,11 +92,12 @@ def generate_saliency_map(  # type: ignore
         """
         self.data_preprocessed = self.preprocess_fn(data)
 
+        logits = self.get_logits(self.data_preprocessed)
         if target_indices is None:
-            num_classes = self.get_num_classes(self.data_preprocessed)
-            if num_classes > 10:
-                logger.info(f"num_classes = {num_classes}, which might take significant time to process.")
+            num_classes = logits.shape[1]
             target_indices = list(range(num_classes))
+        if len(target_indices) > 10:
+            logger.info(f"{len(target_indices)} targets to process, which might take significant time.")
 
         self.num_iterations_per_kernel, self.kernel_widths = self._preset_parameters(
             preset,
@@ -110,6 +112,7 @@ def generate_saliency_map(  # type: ignore
         self._mask_generator = GaussianPerturbationMask(self.input_size)
 
         saliency_maps = {}
+        self.predictions = {}
         for target in target_indices:
             self.kernel_params_hist = collections.defaultdict(list)
             self.pred_score_hist = collections.defaultdict(list)
@@ -119,6 +122,10 @@ def generate_saliency_map(  # type: ignore
             if scale_output:
                 saliency_map_per_target = scaling(saliency_map_per_target)
             saliency_maps[target] = saliency_map_per_target
+            self.predictions[target] = Prediction(
+                label=target,
+                score=logits[0][target],
+            )
         return saliency_maps
 
     @staticmethod
diff --git a/openvino_xai/methods/black_box/aise/detection.py b/openvino_xai/methods/black_box/aise/detection.py
index ae75f6e7..bbc621be 100644
--- a/openvino_xai/methods/black_box/aise/detection.py
+++ b/openvino_xai/methods/black_box/aise/detection.py
@@ -16,6 +16,7 @@
     logger,
     scaling,
 )
+from openvino_xai.methods.base import Prediction
 from openvino_xai.methods.black_box.aise.base import AISEBase, GaussianPerturbationMask
 from openvino_xai.methods.black_box.base import Preset
 from openvino_xai.methods.black_box.utils import check_detection_output
@@ -56,6 +57,7 @@ def __init__(
             prepare_model=prepare_model,
         )
         self.deletion = False
+        self.predictions = {}
 
     def generate_saliency_map(  # type: ignore
         self,
@@ -120,7 +122,7 @@ def generate_saliency_map(  # type: ignore
         self._mask_generator = GaussianPerturbationMask(self.input_size)
 
         saliency_maps = {}
-        self.metadata: Dict[Task, Any] = collections.defaultdict(dict)
+        self.predictions = {}
         for target in target_indices:
             self.target_box = boxes[target]
             self.target_label = labels[target]
@@ -137,7 +139,7 @@ def generate_saliency_map(  # type: ignore
                 saliency_map_per_target = scaling(saliency_map_per_target)
             saliency_maps[target] = saliency_map_per_target
 
-            self._update_metadata(boxes, scores, labels, target, original_size)
+            self._update_predictions(boxes, scores, labels, target, original_size)
         return saliency_maps
 
     @staticmethod
@@ -205,7 +207,7 @@ def _iou(box1: np.ndarray | List[float], box2: np.ndarray | List[float]) -> floa
         area2 = np.prod(box2[2:] - box2[:2])
         return intersection / (area1 + area2 - intersection)
 
-    def _update_metadata(
+    def _update_predictions(
         self,
         boxes: np.ndarray | List,
         scores: np.ndarray | List[float],
@@ -218,4 +220,8 @@ def _update_metadata(
         height_scale = original_size[0] / self.input_size[0]
         x1, x2 = x1 * width_scale, x2 * width_scale
         y1, y2 = y1 * height_scale, y2 * height_scale
-        self.metadata[Task.DETECTION][target] = [x1, y1, x2, y2], scores[target], labels[target]
+        self.predictions[target] = Prediction(
+            label=labels[target],
+            score=scores[target],
+            bounding_box=[x1, y1, x2, y2],
+        )
diff --git a/openvino_xai/methods/black_box/base.py b/openvino_xai/methods/black_box/base.py
index 12302218..36e17be0 100644
--- a/openvino_xai/methods/black_box/base.py
+++ b/openvino_xai/methods/black_box/base.py
@@ -3,6 +3,7 @@
 
 from enum import Enum
 
+import numpy as np
 import openvino.runtime as ov
 
 from openvino_xai.methods.base import MethodBase
@@ -18,12 +19,12 @@ def prepare_model(self, load_model: bool = True) -> ov.Model:
             self.load_model()
         return self._model
 
-    def get_num_classes(self, data_preprocessed):
-        """Estimates number of classes for the classification model. Expects batch dimention."""
+    def get_logits(self, data_preprocessed: np.ndarray) -> np.ndarray:
+        """Gets logits for the classification model. Expects batch dimension."""
         forward_output = self.model_forward(data_preprocessed, preprocess=False)
         logits = self.postprocess_fn(forward_output)
         check_classification_output(logits)
-        return logits.shape[1]
+        return logits
 
 
 class Preset(Enum):
diff --git a/openvino_xai/methods/black_box/rise.py b/openvino_xai/methods/black_box/rise.py
index dec17423..85668d1e 100644
--- a/openvino_xai/methods/black_box/rise.py
+++ b/openvino_xai/methods/black_box/rise.py
@@ -9,6 +9,7 @@
 from tqdm import tqdm
 
 from openvino_xai.common.utils import IdentityPreprocessFN, is_bhwc_layout, scaling
+from openvino_xai.methods.base import Prediction
 from openvino_xai.methods.black_box.base import BlackBoxXAIMethod, Preset
 from openvino_xai.methods.black_box.utils import check_classification_output
 
@@ -132,9 +133,10 @@ def _run_synchronous_explanation(
     ) -> np.ndarray:
         input_size = data_preprocessed.shape[1:3] if is_bhwc_layout(data_preprocessed) else data_preprocessed.shape[2:4]
 
-        num_classes = self.get_num_classes(data_preprocessed)
+        logits = self.get_logits(data_preprocessed)
 
         if target_classes is None:
+            num_classes = logits.shape[1]
             num_targets = num_classes
         else:
             num_targets = len(target_classes)
@@ -159,6 +161,13 @@ def _run_synchronous_explanation(
 
+        # Start from an empty mapping so entries from a previous call cannot leak into this result.
+        self.predictions = {}
         if target_classes is not None:
             saliency_maps = self._reformat_as_dict(saliency_maps, target_classes)
+            for target in target_classes:
+                self.predictions[target] = Prediction(
+                    label=target,
+                    score=logits[0][target],
+                )
         return saliency_maps
 
     @staticmethod