Skip to content

Commit

Permalink
Support prediction info overlay
Browse files Browse the repository at this point in the history
  • Loading branch information
negvet committed Sep 2, 2024
1 parent cf77ec4 commit 1a9fe97
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 39 deletions.
3 changes: 2 additions & 1 deletion openvino_xai/explainer/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,9 @@ def explain(
explanation = Explanation(
saliency_map=saliency_map,
targets=targets,
task=self.task,
label_names=label_names,
metadata=self.method.metadata,
predictions=self.method.predictions,
)
return self._visualize(
original_input_image,
Expand Down
12 changes: 9 additions & 3 deletions openvino_xai/explainer/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
explains_all,
get_target_indices,
)
from openvino_xai.methods.base import Prediction


class Explanation:
Expand All @@ -36,8 +37,9 @@ def __init__(
self,
saliency_map: np.ndarray | Dict[int | str, np.ndarray],
targets: np.ndarray | List[int | str] | int | str,
task: Task,
label_names: List[str] | None = None,
metadata: Dict[Task, Any] | None = None,
predictions: Dict[Task, Prediction] | None = None,
):
targets = convert_targets_to_numpy(targets)

Expand All @@ -57,10 +59,14 @@ def __init__(
self.layout = Layout.MULTIPLE_MAPS_PER_IMAGE_GRAY

if not explains_all(targets) and not self.layout == Layout.ONE_MAP_PER_IMAGE_GRAY:
self._saliency_map = self._select_target_saliency_maps(targets, label_names)
if task == Task.DETECTION:
self._saliency_map = self._select_target_saliency_maps(targets, None)
else:
self._saliency_map = self._select_target_saliency_maps(targets, label_names)

self.task = task
self.label_names = label_names
self.metadata = metadata
self.predictions = predictions

@property
def saliency_map(self) -> Dict[int | str, np.ndarray]:
Expand Down
69 changes: 47 additions & 22 deletions openvino_xai/explainer/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Explanation,
Layout,
)
from openvino_xai.methods.base import Prediction


def resize(saliency_map: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
Expand Down Expand Up @@ -147,7 +148,10 @@ def visualize(
saliency_map_np = self._apply_overlay(
explanation, saliency_map_np, original_input_image, output_size, overlay_weight
)
saliency_map_np = self._apply_metadata(explanation.metadata, saliency_map_np, indices_to_return)
if explanation.task == Task.CLASSIFICATION:
self._put_classification_info(saliency_map_np, indices_to_return, explanation.label_names, explanation.predictions)
if explanation.task == Task.DETECTION and explanation.predictions:
self._put_detection_info(saliency_map_np, indices_to_return, explanation.label_names, explanation.predictions)
else:
if resize:
if original_input_image is None and output_size is None:
Expand All @@ -162,27 +166,48 @@ def visualize(
return self._update_explanation_with_processed_sal_map(explanation, saliency_map_np, indices_to_return)

@staticmethod
def _apply_metadata(metadata: Dict[Task, Any], saliency_map_np: np.ndarray, indices: List[int | str]):
    """Draw detection boxes and ``label|score`` captions onto the saliency maps.

    The maps are drawn on in place. Nothing is drawn when ``metadata`` is
    empty or has no ``Task.DETECTION`` entry.

    Args:
        metadata: Per-task metadata. The ``Task.DETECTION`` entry maps a target
            index to a ``(box, score, label_index)`` triple, where ``box`` is
            ``(x1, y1, x2, y2)``.
        saliency_map_np: Saliency maps, paired by position with ``indices``.
        indices: Target indices used to look up the detection metadata.

    Returns:
        The ``saliency_map_np`` object that was passed in.
    """
    # TODO (negvet): support when indices are strings
    if not metadata or Task.DETECTION not in metadata:
        return saliency_map_np

    red = (255, 0, 0)
    detections = metadata[Task.DETECTION]
    for canvas, target_index in zip(saliency_map_np, indices):
        box, score, label_index = detections[target_index]
        x1, y1, x2, y2 = box
        top_left = (int(x1), int(y1))
        bottom_right = (int(x2), int(y2))
        cv2.rectangle(canvas, top_left, bottom_right, color=red, thickness=2)
        # The caption is anchored 5 px above the top-left corner of the box.
        cv2.putText(
            canvas,
            f"{label_index}|{score:.2f}",
            org=(int(x1), int(y1 - 5)),
            fontFace=1,
            fontScale=1,
            color=red,
            thickness=2,
        )
    return saliency_map_np
def _put_classification_info(
    saliency_map_np: np.ndarray,
    indices: List[int | str],
    label_names: List[str] | None,
    predictions: Dict[int, Prediction] | None,
) -> None:
    """Write the target label, and its score when known, in the corner of each map.

    The text is drawn in place on the entries of ``saliency_map_np``.

    Args:
        saliency_map_np: Saliency maps, paired by position with ``indices``.
        indices: Targets the maps belong to.
        label_names: Optional class names, indexed by integer target.
        predictions: Optional predictions keyed by target; only ``score`` is read.
    """
    corner_location = 3, 17
    for saliency_map, target_index in zip(saliency_map_np, indices):
        if label_names and not isinstance(target_index, str):
            label = label_names[target_index]
        else:
            # A string target cannot index ``label_names``; show it as-is.
            label = str(target_index)

        prediction = predictions.get(target_index) if predictions else None
        # Compare against None rather than truthiness: a score of exactly 0.0
        # is a valid value and must still be rendered.
        if prediction is not None and prediction.score is not None:
            label = f"{label}|{prediction.score:.2f}"

        cv2.putText(
            saliency_map,
            label,
            org=corner_location,
            fontFace=cv2.FONT_HERSHEY_PLAIN,  # same value as the literal 1 used before
            fontScale=1.3,
            color=(255, 0, 0),
            thickness=2,
        )

@staticmethod
def _put_detection_info(
    saliency_map_np: np.ndarray,
    indices: List[int | str],
    label_names: List[str] | None,
    predictions: Dict[int, Prediction] | None,
) -> None:
    """Draw the predicted box and a ``label|score`` caption on each saliency map.

    The drawing happens in place on the entries of ``saliency_map_np``. Targets
    without a prediction, or whose prediction has no bounding box, are left
    untouched instead of raising.

    Args:
        saliency_map_np: Saliency maps, paired by position with ``indices``.
        indices: Targets used as keys into ``predictions``.
        label_names: Optional class names, indexed by the predicted label.
        predictions: Predictions keyed by target; ``label``, ``score`` and
            ``bounding_box`` (``x1, y1, x2, y2``) are read.
    """
    red = (255, 0, 0)
    for saliency_map, target_index in zip(saliency_map_np, indices):
        prediction = predictions.get(target_index) if predictions else None
        if prediction is None or prediction.bounding_box is None:
            # Every Prediction field defaults to None, so there may be nothing to draw.
            continue

        x1, y1, x2, y2 = prediction.bounding_box
        cv2.rectangle(saliency_map, (int(x1), int(y1)), (int(x2), int(y2)), color=red, thickness=2)

        label = prediction.label
        if label_names and label is not None:
            label = label_names[label]

        # Build "label|score" from whichever parts are available.
        caption_parts = []
        if label is not None:
            caption_parts.append(str(label))
        if prediction.score is not None:
            caption_parts.append(f"{prediction.score:.2f}")
        if not caption_parts:
            continue

        cv2.putText(
            saliency_map,
            "|".join(caption_parts),
            org=(int(x1), int(y1 - 5)),  # just above the top-left corner of the box
            fontFace=cv2.FONT_HERSHEY_PLAIN,  # same value as the literal 1 used before
            fontScale=1.3,
            color=red,
            thickness=2,
        )

@staticmethod
def _apply_scaling(explanation: Explanation, saliency_map_np: np.ndarray) -> np.ndarray:
Expand Down
12 changes: 10 additions & 2 deletions openvino_xai/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import collections
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Mapping
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Tuple

import numpy as np
import openvino as ov
Expand All @@ -25,7 +26,7 @@ def __init__(
self._model_compiled = None
self.preprocess_fn = preprocess_fn
self._device_name = device_name
self.metadata: Dict[Task, Any] = collections.defaultdict(dict)
self.predictions = {}

@property
def model_compiled(self) -> ov.CompiledModel | None:
Expand All @@ -50,3 +51,10 @@ def generate_saliency_map(self, data: np.ndarray) -> Dict[int, np.ndarray] | np.
def load_model(self) -> None:
core = ov.Core()
self._model_compiled = core.compile_model(model=self._model, device_name=self._device_name)


@dataclass
class Prediction:
    """Model prediction associated with one explained target.

    Every field is optional and defaults to ``None``.
    """

    # Class index of the prediction (the classification methods store the
    # target index itself here).
    label: int | None = None
    # Score of the prediction; the classification methods store the value
    # taken from the model logits for the target.
    score: float | None = None
    # Box coordinates in (x1, y1, x2, y2) order.
    # NOTE(review): the AISE detection method stores a 4-element list here,
    # not a tuple - confirm whether the annotation should be widened.
    bounding_box: Tuple | None = None
13 changes: 10 additions & 3 deletions openvino_xai/methods/black_box/aise/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
scaling,
sigmoid,
)
from openvino_xai.methods.base import Prediction
from openvino_xai.methods.black_box.aise.base import AISEBase, GaussianPerturbationMask
from openvino_xai.methods.black_box.base import Preset
from openvino_xai.methods.black_box.utils import check_classification_output
Expand Down Expand Up @@ -91,11 +92,12 @@ def generate_saliency_map( # type: ignore
"""
self.data_preprocessed = self.preprocess_fn(data)

logits = self.get_logits(self.data_preprocessed)
if target_indices is None:
num_classes = self.get_num_classes(self.data_preprocessed)
if num_classes > 10:
logger.info(f"num_classes = {num_classes}, which might take significant time to process.")
num_classes = logits.shape[1]
target_indices = list(range(num_classes))
if len(target_indices) > 10:
logger.info(f"{len(target_indices)} targets to process, which might take significant time.")

self.num_iterations_per_kernel, self.kernel_widths = self._preset_parameters(
preset,
Expand All @@ -110,6 +112,7 @@ def generate_saliency_map( # type: ignore
self._mask_generator = GaussianPerturbationMask(self.input_size)

saliency_maps = {}
self.predictions = {}
for target in target_indices:
self.kernel_params_hist = collections.defaultdict(list)
self.pred_score_hist = collections.defaultdict(list)
Expand All @@ -119,6 +122,10 @@ def generate_saliency_map( # type: ignore
if scale_output:
saliency_map_per_target = scaling(saliency_map_per_target)
saliency_maps[target] = saliency_map_per_target
self.predictions[target] = Prediction(
label=target,
score=logits[0][target],
)
return saliency_maps

@staticmethod
Expand Down
14 changes: 10 additions & 4 deletions openvino_xai/methods/black_box/aise/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
logger,
scaling,
)
from openvino_xai.methods.base import Prediction
from openvino_xai.methods.black_box.aise.base import AISEBase, GaussianPerturbationMask
from openvino_xai.methods.black_box.base import Preset
from openvino_xai.methods.black_box.utils import check_detection_output
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(
prepare_model=prepare_model,
)
self.deletion = False
self.predictions = {}

def generate_saliency_map( # type: ignore
self,
Expand Down Expand Up @@ -120,7 +122,7 @@ def generate_saliency_map( # type: ignore
self._mask_generator = GaussianPerturbationMask(self.input_size)

saliency_maps = {}
self.metadata: Dict[Task, Any] = collections.defaultdict(dict)
self.predictions = {}
for target in target_indices:
self.target_box = boxes[target]
self.target_label = labels[target]
Expand All @@ -137,7 +139,7 @@ def generate_saliency_map( # type: ignore
saliency_map_per_target = scaling(saliency_map_per_target)
saliency_maps[target] = saliency_map_per_target

self._update_metadata(boxes, scores, labels, target, original_size)
self._update_predictions(boxes, scores, labels, target, original_size)
return saliency_maps

@staticmethod
Expand Down Expand Up @@ -205,7 +207,7 @@ def _iou(box1: np.ndarray | List[float], box2: np.ndarray | List[float]) -> floa
area2 = np.prod(box2[2:] - box2[:2])
return intersection / (area1 + area2 - intersection)

def _update_metadata(
def _update_predictions(
self,
boxes: np.ndarray | List,
scores: np.ndarray | List[float],
Expand All @@ -218,4 +220,8 @@ def _update_metadata(
height_scale = original_size[0] / self.input_size[0]
x1, x2 = x1 * width_scale, x2 * width_scale
y1, y2 = y1 * height_scale, y2 * height_scale
self.metadata[Task.DETECTION][target] = [x1, y1, x2, y2], scores[target], labels[target]
self.predictions[target] = Prediction(
label=labels[target],
score=scores[target],
bounding_box=[x1, y1, x2, y2],
)
7 changes: 4 additions & 3 deletions openvino_xai/methods/black_box/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from enum import Enum

import numpy as np
import openvino.runtime as ov

from openvino_xai.methods.base import MethodBase
Expand All @@ -18,12 +19,12 @@ def prepare_model(self, load_model: bool = True) -> ov.Model:
self.load_model()
return self._model

# NOTE(review): this diff view shows the previous `get_num_classes` lines and
# the new `get_logits` lines back to back.
def get_num_classes(self, data_preprocessed):
"""Estimates number of classes for the classification model. Expects batch dimension."""
def get_logits(self, data_preprocessed: np.ndarray) -> np.ndarray:
"""Gets logits for the classification model. Expects batch dimension."""
# The input is already preprocessed, so preprocessing is disabled for this forward pass.
forward_output = self.model_forward(data_preprocessed, preprocess=False)
logits = self.postprocess_fn(forward_output)
# Presumably validates the shape of the post-processed output - TODO confirm.
check_classification_output(logits)
return logits.shape[1]
return logits


class Preset(Enum):
Expand Down
9 changes: 8 additions & 1 deletion openvino_xai/methods/black_box/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tqdm import tqdm

from openvino_xai.common.utils import IdentityPreprocessFN, is_bhwc_layout, scaling
from openvino_xai.methods.base import Prediction
from openvino_xai.methods.black_box.base import BlackBoxXAIMethod, Preset
from openvino_xai.methods.black_box.utils import check_classification_output

Expand Down Expand Up @@ -132,9 +133,10 @@ def _run_synchronous_explanation(
) -> np.ndarray:
input_size = data_preprocessed.shape[1:3] if is_bhwc_layout(data_preprocessed) else data_preprocessed.shape[2:4]

num_classes = self.get_num_classes(data_preprocessed)
logits = self.get_logits(data_preprocessed)

if target_classes is None:
num_classes = logits.shape[1]
num_targets = num_classes
else:
num_targets = len(target_classes)
Expand All @@ -159,6 +161,11 @@ def _run_synchronous_explanation(

if target_classes is not None:
saliency_maps = self._reformat_as_dict(saliency_maps, target_classes)
for target in target_classes:
self.predictions[target] = Prediction(
label=target,
score=logits[0][target],
)
return saliency_maps

@staticmethod
Expand Down

0 comments on commit 1a9fe97

Please sign in to comment.