Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support per-target Prediction #62

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/run_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def explain_auto(args):
# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "explain_auto"
explanation.save(output, Path(args.image_path).stem)
explanation.save(output, f"{Path(args.image_path).stem}_")


def explain_white_box(args):
Expand Down Expand Up @@ -117,7 +117,7 @@ def explain_white_box(args):
# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "explain_white_box"
explanation.save(output, Path(args.image_path).stem)
explanation.save(output, f"{Path(args.image_path).stem}_")


def explain_black_box(args):
Expand Down Expand Up @@ -160,7 +160,7 @@ def explain_black_box(args):
# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "explain_black_box"
explanation.save(output, Path(args.image_path).stem)
explanation.save(output, f"{Path(args.image_path).stem}_")


def explain_white_box_multiple_images(args):
Expand Down Expand Up @@ -203,7 +203,7 @@ def explain_white_box_multiple_images(args):
# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "explain_white_box_multiple_images"
explanation[0].save(output, Path(args.image_path).stem)
explanation[0].save(output, f"{Path(args.image_path).stem}_")


def explain_white_box_vit(args):
Expand Down Expand Up @@ -241,7 +241,7 @@ def explain_white_box_vit(args):
# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "explain_white_box_vit"
explanation.save(output, Path(args.image_path).stem)
explanation.save(output, f"{Path(args.image_path).stem}_")


def insert_xai(args):
Expand Down
3 changes: 2 additions & 1 deletion openvino_xai/explainer/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,9 @@ def explain(
explanation = Explanation(
saliency_map=saliency_map,
targets=targets,
task=self.task,
label_names=label_names,
metadata=self.method.metadata,
predictions=self.method.predictions,
)
return self._visualize(
original_input_image,
Expand Down
18 changes: 13 additions & 5 deletions openvino_xai/explainer/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List
from typing import Dict, List

import cv2
import matplotlib.pyplot as plt
Expand All @@ -17,6 +17,7 @@
explains_all,
get_target_indices,
)
from openvino_xai.methods.base import Prediction


class Explanation:
Expand All @@ -28,16 +29,21 @@ class Explanation:
:param targets: List of custom labels to explain, optional. Can be list of integer indices (int),
or list of names (str) from label_names.
:type targets: np.ndarray | List[int | str] | int | str
:param task: Type of the task: CLASSIFICATION or DETECTION.
:type task: Task
:param label_names: List of all label names.
:type label_names: List[str] | None
:param predictions: Per-target model prediction (available only for black-box methods).
:type predictions: Dict[int, Prediction] | None
"""

def __init__(
self,
saliency_map: np.ndarray | Dict[int | str, np.ndarray],
targets: np.ndarray | List[int | str] | int | str,
task: Task,
label_names: List[str] | None = None,
metadata: Dict[Task, Any] | None = None,
predictions: Dict[int, Prediction] | None = None,
):
targets = convert_targets_to_numpy(targets)

Expand All @@ -57,10 +63,12 @@ def __init__(
self.layout = Layout.MULTIPLE_MAPS_PER_IMAGE_GRAY

if not explains_all(targets) and not self.layout == Layout.ONE_MAP_PER_IMAGE_GRAY:
self._saliency_map = self._select_target_saliency_maps(targets, label_names)
label_names_ = None if task == Task.DETECTION else label_names
self._saliency_map = self._select_target_saliency_maps(targets, label_names_)

self.task = task
self.label_names = label_names
self.metadata = metadata
self.predictions = predictions

@property
def saliency_map(self) -> Dict[int | str, np.ndarray]:
Expand Down Expand Up @@ -180,7 +188,7 @@ def save(
map_to_save = cv2.cvtColor(map_to_save, code=cv2.COLOR_RGB2BGR)
if isinstance(target_idx, str):
target_name = "activation_map"
elif self.label_names and isinstance(target_idx, np.int64):
elif self.label_names and isinstance(target_idx, np.int64) and self.task != Task.DETECTION:
target_name = self.label_names[target_idx]
else:
target_name = str(target_idx)
Expand Down
93 changes: 70 additions & 23 deletions openvino_xai/explainer/visualizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Tuple
from typing import Dict, List, Tuple

import cv2
import numpy as np
Expand All @@ -16,6 +16,7 @@
Explanation,
Layout,
)
from openvino_xai.methods.base import Prediction


def resize(saliency_map: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
Expand Down Expand Up @@ -80,6 +81,7 @@
colormap: bool = True,
overlay: bool = False,
overlay_weight: float = 0.5,
overlay_prediction: bool = True,
) -> Explanation:
return self.visualize(
explanation,
Expand All @@ -90,6 +92,7 @@
colormap,
overlay,
overlay_weight,
overlay_prediction,
)

def visualize(
Expand All @@ -102,6 +105,7 @@
colormap: bool = True,
overlay: bool = False,
overlay_weight: float = 0.5,
overlay_prediction: bool = True,
) -> Explanation:
"""
Saliency map postprocess method.
Expand All @@ -126,6 +130,8 @@
:type overlay: bool
:parameter overlay_weight: Weight of the saliency map when overlaying the input data with the saliency map.
:type overlay_weight: float
:parameter overlay_prediction: If True, plot model prediction over the overlay.
:type overlay_prediction: bool
"""
if original_input_image is not None:
original_input_image = format_to_bhwc(original_input_image)
Expand All @@ -147,7 +153,14 @@
saliency_map_np = self._apply_overlay(
explanation, saliency_map_np, original_input_image, output_size, overlay_weight
)
saliency_map_np = self._apply_metadata(explanation.metadata, saliency_map_np, indices_to_return)
if overlay_prediction and explanation.task == Task.CLASSIFICATION:
self._put_classification_info(
saliency_map_np, indices_to_return, explanation.label_names, explanation.predictions # type:ignore
)
if overlay_prediction and explanation.task == Task.DETECTION:
self._put_detection_info(
saliency_map_np, indices_to_return, explanation.label_names, explanation.predictions # type:ignore
)
else:
if resize:
if original_input_image is None and output_size is None:
Expand All @@ -162,27 +175,61 @@
return self._update_explanation_with_processed_sal_map(explanation, saliency_map_np, indices_to_return)

@staticmethod
def _apply_metadata(metadata: Dict[Task, Any], saliency_map_np: np.ndarray, indices: List[int | str]):
# TODO (negvet): support when indices are strings
if metadata:
if Task.DETECTION in metadata:
for smap_i, target_index in zip(range(len(saliency_map_np)), indices):
saliency_map = saliency_map_np[smap_i]
box, score, label_index = metadata[Task.DETECTION][target_index]
x1, y1, x2, y2 = box
cv2.rectangle(saliency_map, (int(x1), int(y1)), (int(x2), int(y2)), color=(255, 0, 0), thickness=2)
box_label = f"{label_index}|{score:.2f}"
box_label_loc = int(x1), int(y1 - 5)
cv2.putText(
saliency_map,
box_label,
org=box_label_loc,
fontFace=1,
fontScale=1,
color=(255, 0, 0),
thickness=2,
)
return saliency_map_np
def _put_classification_info(
saliency_map_np: np.ndarray,
indices: List[int],
label_names: List[str] | None,
predictions: Dict[int, Prediction] | None,
) -> None:
corner_location = 3, 17
for smap, target_index in zip(range(len(saliency_map_np)), indices):
label = label_names[target_index] if label_names else str(target_index)
if predictions and target_index in predictions:
score = predictions[target_index].score
if score:
label = f"{label}|{score:.2f}"

cv2.putText(
saliency_map_np[smap],
label,
org=corner_location,
fontFace=1,
fontScale=1.3,
color=(255, 0, 0),
thickness=2,
)

@staticmethod
def _put_detection_info(
saliency_map_np: np.ndarray,
indices: List[int],
label_names: List[str] | None,
predictions: Dict[int, Prediction] | None,
) -> None:
if not predictions:
return

Check warning on line 210 in openvino_xai/explainer/visualizer.py

View check run for this annotation

Codecov / codecov/patch

openvino_xai/explainer/visualizer.py#L210

Added line #L210 was not covered by tests

for smap, target_index in zip(range(len(saliency_map_np)), indices):
saliency_map = saliency_map_np[smap]
label_index = predictions[target_index].label
score = predictions[target_index].score
box = predictions[target_index].bounding_box

x1, y1, x2, y2 = np.array(box, dtype=np.int32)
cv2.rectangle(saliency_map, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)

label = label_names[label_index] if label_names else label_index
label_score = f"{label}|{score:.2f}"
box_location = int(x1), int(y1 - 5)
cv2.putText(
saliency_map,
label_score,
org=box_location,
fontFace=1,
fontScale=1.3,
color=(255, 0, 0),
thickness=2,
)

@staticmethod
def _apply_scaling(explanation: Explanation, saliency_map_np: np.ndarray) -> np.ndarray:
Expand Down
14 changes: 10 additions & 4 deletions openvino_xai/methods/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import collections
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Mapping
from dataclasses import dataclass
from typing import Callable, Dict, List, Mapping, Tuple

import numpy as np
import openvino as ov

from openvino_xai.common.parameters import Task
from openvino_xai.common.utils import IdentityPreprocessFN


Expand All @@ -25,7 +24,7 @@ def __init__(
self._model_compiled = None
self.preprocess_fn = preprocess_fn
self._device_name = device_name
self.metadata: Dict[Task, Any] = collections.defaultdict(dict)
self.predictions: Dict[int, Prediction] = {}

@property
def model_compiled(self) -> ov.CompiledModel | None:
Expand All @@ -50,3 +49,10 @@ def generate_saliency_map(self, data: np.ndarray) -> Dict[int, np.ndarray] | np.
def load_model(self) -> None:
core = ov.Core()
self._model_compiled = core.compile_model(model=self._model, device_name=self._device_name)


@dataclass
class Prediction:
label: int | None = None
score: float | None = None
bounding_box: List | Tuple | None = None
5 changes: 3 additions & 2 deletions openvino_xai/methods/black_box/aise/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ def __init__(
device_name: str = "CPU",
prepare_model: bool = True,
):
super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name)
self.postprocess_fn = postprocess_fn
super().__init__(
model=model, postprocess_fn=postprocess_fn, preprocess_fn=preprocess_fn, device_name=device_name
)

self.data_preprocessed = None
self.target: int | None = None
Expand Down
13 changes: 10 additions & 3 deletions openvino_xai/methods/black_box/aise/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
scaling,
sigmoid,
)
from openvino_xai.methods.base import Prediction
from openvino_xai.methods.black_box.aise.base import AISEBase, GaussianPerturbationMask
from openvino_xai.methods.black_box.base import Preset
from openvino_xai.methods.black_box.utils import check_classification_output
Expand Down Expand Up @@ -91,11 +92,12 @@ def generate_saliency_map( # type: ignore
"""
self.data_preprocessed = self.preprocess_fn(data)

logits = self.get_logits(self.data_preprocessed)
if target_indices is None:
num_classes = self.get_num_classes(self.data_preprocessed)
if num_classes > 10:
logger.info(f"num_classes = {num_classes}, which might take significant time to process.")
num_classes = logits.shape[1]
target_indices = list(range(num_classes))
if len(target_indices) > 10:
logger.info(f"{len(target_indices)} targets to process, which might take significant time.")

self.num_iterations_per_kernel, self.kernel_widths = self._preset_parameters(
preset,
Expand All @@ -110,6 +112,7 @@ def generate_saliency_map( # type: ignore
self._mask_generator = GaussianPerturbationMask(self.input_size)

saliency_maps = {}
self.predictions = {}
for target in target_indices:
self.kernel_params_hist = collections.defaultdict(list)
self.pred_score_hist = collections.defaultdict(list)
Expand All @@ -119,6 +122,10 @@ def generate_saliency_map( # type: ignore
if scale_output:
saliency_map_per_target = scaling(saliency_map_per_target)
saliency_maps[target] = saliency_map_per_target
self.predictions[target] = Prediction(
label=target,
score=logits[0][target],
)
return saliency_maps

@staticmethod
Expand Down
Loading