Skip to content

Commit

Permalink
Support per-target Prediction (#62)
Browse files Browse the repository at this point in the history
* Support prediction info overlay

* misc

* tests + run_cls
  • Loading branch information
negvet authored Sep 4, 2024
1 parent cf77ec4 commit f637a5f
Show file tree
Hide file tree
Showing 17 changed files with 187 additions and 65 deletions.
10 changes: 5 additions & 5 deletions examples/run_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def explain_auto(args):
# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "explain_auto"
explanation.save(output, Path(args.image_path).stem)
explanation.save(output, f"{Path(args.image_path).stem}_")


def explain_white_box(args):
Expand Down Expand Up @@ -117,7 +117,7 @@ def explain_white_box(args):
# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "explain_white_box"
explanation.save(output, Path(args.image_path).stem)
explanation.save(output, f"{Path(args.image_path).stem}_")


def explain_black_box(args):
Expand Down Expand Up @@ -160,7 +160,7 @@ def explain_black_box(args):
# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "explain_black_box"
explanation.save(output, Path(args.image_path).stem)
explanation.save(output, f"{Path(args.image_path).stem}_")


def explain_white_box_multiple_images(args):
Expand Down Expand Up @@ -203,7 +203,7 @@ def explain_white_box_multiple_images(args):
# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "explain_white_box_multiple_images"
explanation[0].save(output, Path(args.image_path).stem)
explanation[0].save(output, f"{Path(args.image_path).stem}_")


def explain_white_box_vit(args):
Expand Down Expand Up @@ -241,7 +241,7 @@ def explain_white_box_vit(args):
# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "explain_white_box_vit"
explanation.save(output, Path(args.image_path).stem)
explanation.save(output, f"{Path(args.image_path).stem}_")


def insert_xai(args):
Expand Down
3 changes: 2 additions & 1 deletion openvino_xai/explainer/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,9 @@ def explain(
explanation = Explanation(
saliency_map=saliency_map,
targets=targets,
task=self.task,
label_names=label_names,
metadata=self.method.metadata,
predictions=self.method.predictions,
)
return self._visualize(
original_input_image,
Expand Down
18 changes: 13 additions & 5 deletions openvino_xai/explainer/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List
from typing import Dict, List

import cv2
import matplotlib.pyplot as plt
Expand All @@ -17,6 +17,7 @@
explains_all,
get_target_indices,
)
from openvino_xai.methods.base import Prediction


class Explanation:
Expand All @@ -28,16 +29,21 @@ class Explanation:
:param targets: List of custom labels to explain, optional. Can be list of integer indices (int),
or list of names (str) from label_names.
:type targets: np.ndarray | List[int | str] | int | str
:param task: Type of the task: CLASSIFICATION or DETECTION.
:type task: Task
:param label_names: List of all label names.
:type label_names: List[str] | None
:param predictions: Per-target model prediction (available only for black-box methods).
:type predictions: Dict[int, Prediction] | None
"""

def __init__(
self,
saliency_map: np.ndarray | Dict[int | str, np.ndarray],
targets: np.ndarray | List[int | str] | int | str,
task: Task,
label_names: List[str] | None = None,
metadata: Dict[Task, Any] | None = None,
predictions: Dict[int, Prediction] | None = None,
):
targets = convert_targets_to_numpy(targets)

Expand All @@ -57,10 +63,12 @@ def __init__(
self.layout = Layout.MULTIPLE_MAPS_PER_IMAGE_GRAY

if not explains_all(targets) and not self.layout == Layout.ONE_MAP_PER_IMAGE_GRAY:
self._saliency_map = self._select_target_saliency_maps(targets, label_names)
label_names_ = None if task == Task.DETECTION else label_names
self._saliency_map = self._select_target_saliency_maps(targets, label_names_)

self.task = task
self.label_names = label_names
self.metadata = metadata
self.predictions = predictions

@property
def saliency_map(self) -> Dict[int | str, np.ndarray]:
Expand Down Expand Up @@ -180,7 +188,7 @@ def save(
map_to_save = cv2.cvtColor(map_to_save, code=cv2.COLOR_RGB2BGR)
if isinstance(target_idx, str):
target_name = "activation_map"
elif self.label_names and isinstance(target_idx, np.int64):
elif self.label_names and isinstance(target_idx, np.int64) and self.task != Task.DETECTION:
target_name = self.label_names[target_idx]
else:
target_name = str(target_idx)
Expand Down
93 changes: 70 additions & 23 deletions openvino_xai/explainer/visualizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Tuple
from typing import Dict, List, Tuple

import cv2
import numpy as np
Expand All @@ -16,6 +16,7 @@
Explanation,
Layout,
)
from openvino_xai.methods.base import Prediction


def resize(saliency_map: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
Expand Down Expand Up @@ -80,6 +81,7 @@ def __call__(
colormap: bool = True,
overlay: bool = False,
overlay_weight: float = 0.5,
overlay_prediction: bool = True,
) -> Explanation:
return self.visualize(
explanation,
Expand All @@ -90,6 +92,7 @@ def __call__(
colormap,
overlay,
overlay_weight,
overlay_prediction,
)

def visualize(
Expand All @@ -102,6 +105,7 @@ def visualize(
colormap: bool = True,
overlay: bool = False,
overlay_weight: float = 0.5,
overlay_prediction: bool = True,
) -> Explanation:
"""
Saliency map postprocess method.
Expand All @@ -126,6 +130,8 @@ def visualize(
:type overlay: bool
:parameter overlay_weight: Weight of the saliency map when overlaying the input data with the saliency map.
:type overlay_weight: float
:parameter overlay_prediction: If True, plot model prediction over the overlay.
:type overlay_prediction: bool
"""
if original_input_image is not None:
original_input_image = format_to_bhwc(original_input_image)
Expand All @@ -147,7 +153,14 @@ def visualize(
saliency_map_np = self._apply_overlay(
explanation, saliency_map_np, original_input_image, output_size, overlay_weight
)
saliency_map_np = self._apply_metadata(explanation.metadata, saliency_map_np, indices_to_return)
if overlay_prediction and explanation.task == Task.CLASSIFICATION:
self._put_classification_info(
saliency_map_np, indices_to_return, explanation.label_names, explanation.predictions # type:ignore
)
if overlay_prediction and explanation.task == Task.DETECTION:
self._put_detection_info(
saliency_map_np, indices_to_return, explanation.label_names, explanation.predictions # type:ignore
)
else:
if resize:
if original_input_image is None and output_size is None:
Expand All @@ -162,27 +175,61 @@ def visualize(
return self._update_explanation_with_processed_sal_map(explanation, saliency_map_np, indices_to_return)

@staticmethod
def _apply_metadata(metadata: Dict[Task, Any], saliency_map_np: np.ndarray, indices: List[int | str]):
# TODO (negvet): support when indices are strings
if metadata:
if Task.DETECTION in metadata:
for smap_i, target_index in zip(range(len(saliency_map_np)), indices):
saliency_map = saliency_map_np[smap_i]
box, score, label_index = metadata[Task.DETECTION][target_index]
x1, y1, x2, y2 = box
cv2.rectangle(saliency_map, (int(x1), int(y1)), (int(x2), int(y2)), color=(255, 0, 0), thickness=2)
box_label = f"{label_index}|{score:.2f}"
box_label_loc = int(x1), int(y1 - 5)
cv2.putText(
saliency_map,
box_label,
org=box_label_loc,
fontFace=1,
fontScale=1,
color=(255, 0, 0),
thickness=2,
)
return saliency_map_np
def _put_classification_info(
saliency_map_np: np.ndarray,
indices: List[int],
label_names: List[str] | None,
predictions: Dict[int, Prediction] | None,
) -> None:
corner_location = 3, 17
for smap, target_index in zip(range(len(saliency_map_np)), indices):
label = label_names[target_index] if label_names else str(target_index)
if predictions and target_index in predictions:
score = predictions[target_index].score
if score:
label = f"{label}|{score:.2f}"

cv2.putText(
saliency_map_np[smap],
label,
org=corner_location,
fontFace=1,
fontScale=1.3,
color=(255, 0, 0),
thickness=2,
)

@staticmethod
def _put_detection_info(
saliency_map_np: np.ndarray,
indices: List[int],
label_names: List[str] | None,
predictions: Dict[int, Prediction] | None,
) -> None:
if not predictions:
return

for smap, target_index in zip(range(len(saliency_map_np)), indices):
saliency_map = saliency_map_np[smap]
label_index = predictions[target_index].label
score = predictions[target_index].score
box = predictions[target_index].bounding_box

x1, y1, x2, y2 = np.array(box, dtype=np.int32)
cv2.rectangle(saliency_map, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)

label = label_names[label_index] if label_names else label_index
label_score = f"{label}|{score:.2f}"
box_location = int(x1), int(y1 - 5)
cv2.putText(
saliency_map,
label_score,
org=box_location,
fontFace=1,
fontScale=1.3,
color=(255, 0, 0),
thickness=2,
)

@staticmethod
def _apply_scaling(explanation: Explanation, saliency_map_np: np.ndarray) -> np.ndarray:
Expand Down
14 changes: 10 additions & 4 deletions openvino_xai/methods/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import collections
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Mapping
from dataclasses import dataclass
from typing import Callable, Dict, List, Mapping, Tuple

import numpy as np
import openvino as ov

from openvino_xai.common.parameters import Task
from openvino_xai.common.utils import IdentityPreprocessFN


Expand All @@ -25,7 +24,7 @@ def __init__(
self._model_compiled = None
self.preprocess_fn = preprocess_fn
self._device_name = device_name
self.metadata: Dict[Task, Any] = collections.defaultdict(dict)
self.predictions: Dict[int, Prediction] = {}

@property
def model_compiled(self) -> ov.CompiledModel | None:
Expand All @@ -50,3 +49,10 @@ def generate_saliency_map(self, data: np.ndarray) -> Dict[int, np.ndarray] | np.
def load_model(self) -> None:
core = ov.Core()
self._model_compiled = core.compile_model(model=self._model, device_name=self._device_name)


@dataclass
class Prediction:
label: int | None = None
score: float | None = None
bounding_box: List | Tuple | None = None
5 changes: 3 additions & 2 deletions openvino_xai/methods/black_box/aise/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ def __init__(
device_name: str = "CPU",
prepare_model: bool = True,
):
super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name)
self.postprocess_fn = postprocess_fn
super().__init__(
model=model, postprocess_fn=postprocess_fn, preprocess_fn=preprocess_fn, device_name=device_name
)

self.data_preprocessed = None
self.target: int | None = None
Expand Down
13 changes: 10 additions & 3 deletions openvino_xai/methods/black_box/aise/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
scaling,
sigmoid,
)
from openvino_xai.methods.base import Prediction
from openvino_xai.methods.black_box.aise.base import AISEBase, GaussianPerturbationMask
from openvino_xai.methods.black_box.base import Preset
from openvino_xai.methods.black_box.utils import check_classification_output
Expand Down Expand Up @@ -91,11 +92,12 @@ def generate_saliency_map( # type: ignore
"""
self.data_preprocessed = self.preprocess_fn(data)

logits = self.get_logits(self.data_preprocessed)
if target_indices is None:
num_classes = self.get_num_classes(self.data_preprocessed)
if num_classes > 10:
logger.info(f"num_classes = {num_classes}, which might take significant time to process.")
num_classes = logits.shape[1]
target_indices = list(range(num_classes))
if len(target_indices) > 10:
logger.info(f"{len(target_indices)} targets to process, which might take significant time.")

self.num_iterations_per_kernel, self.kernel_widths = self._preset_parameters(
preset,
Expand All @@ -110,6 +112,7 @@ def generate_saliency_map( # type: ignore
self._mask_generator = GaussianPerturbationMask(self.input_size)

saliency_maps = {}
self.predictions = {}
for target in target_indices:
self.kernel_params_hist = collections.defaultdict(list)
self.pred_score_hist = collections.defaultdict(list)
Expand All @@ -119,6 +122,10 @@ def generate_saliency_map( # type: ignore
if scale_output:
saliency_map_per_target = scaling(saliency_map_per_target)
saliency_maps[target] = saliency_map_per_target
self.predictions[target] = Prediction(
label=target,
score=logits[0][target],
)
return saliency_maps

@staticmethod
Expand Down
Loading

0 comments on commit f637a5f

Please sign in to comment.