Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support per-target Prediction #62

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion openvino_xai/explainer/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,9 @@ def explain(
explanation = Explanation(
saliency_map=saliency_map,
targets=targets,
task=self.task,
label_names=label_names,
metadata=self.method.metadata,
predictions=self.method.predictions,
)
return self._visualize(
original_input_image,
Expand Down
12 changes: 9 additions & 3 deletions openvino_xai/explainer/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
explains_all,
get_target_indices,
)
from openvino_xai.methods.base import Prediction


class Explanation:
Expand All @@ -36,8 +37,9 @@ def __init__(
self,
saliency_map: np.ndarray | Dict[int | str, np.ndarray],
targets: np.ndarray | List[int | str] | int | str,
task: Task,
label_names: List[str] | None = None,
metadata: Dict[Task, Any] | None = None,
predictions: Dict[Task, Prediction] | None = None,
negvet marked this conversation as resolved.
Show resolved Hide resolved
):
targets = convert_targets_to_numpy(targets)

Expand All @@ -57,10 +59,14 @@ def __init__(
self.layout = Layout.MULTIPLE_MAPS_PER_IMAGE_GRAY

if not explains_all(targets) and not self.layout == Layout.ONE_MAP_PER_IMAGE_GRAY:
self._saliency_map = self._select_target_saliency_maps(targets, label_names)
if task == Task.DETECTION:
self._saliency_map = self._select_target_saliency_maps(targets, None)
else:
self._saliency_map = self._select_target_saliency_maps(targets, label_names)
negvet marked this conversation as resolved.
Show resolved Hide resolved

self.task = task
self.label_names = label_names
self.metadata = metadata
self.predictions = predictions

@property
def saliency_map(self) -> Dict[int | str, np.ndarray]:
Expand Down
69 changes: 47 additions & 22 deletions openvino_xai/explainer/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Explanation,
Layout,
)
from openvino_xai.methods.base import Prediction


def resize(saliency_map: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
Expand Down Expand Up @@ -147,7 +148,10 @@ def visualize(
saliency_map_np = self._apply_overlay(
explanation, saliency_map_np, original_input_image, output_size, overlay_weight
)
saliency_map_np = self._apply_metadata(explanation.metadata, saliency_map_np, indices_to_return)
if explanation.task == Task.CLASSIFICATION:
self._put_classification_info(saliency_map_np, indices_to_return, explanation.label_names, explanation.predictions)
if explanation.task == Task.DETECTION and explanation.predictions:
self._put_detection_info(saliency_map_np, indices_to_return, explanation.label_names, explanation.predictions)
else:
if resize:
if original_input_image is None and output_size is None:
Expand All @@ -162,27 +166,48 @@ def visualize(
return self._update_explanation_with_processed_sal_map(explanation, saliency_map_np, indices_to_return)

@staticmethod
def _apply_metadata(metadata: Dict[Task, Any], saliency_map_np: np.ndarray, indices: List[int | str]):
# Draws the detection metadata (bounding box plus "label_index|score" text) onto
# each saliency map in place and returns the same `saliency_map_np` object.
# Nothing is drawn when `metadata` is empty or has no Task.DETECTION entry.
# NOTE(review): this looks like the pre-change helper; the `visualize` hunk in this
# diff calls `_put_classification_info` / `_put_detection_info` - confirm it is removed.
# TODO (negvet): support when indices are strings
if metadata:
if Task.DETECTION in metadata:
# Pair the i-th saliency map with the i-th target index.
for smap_i, target_index in zip(range(len(saliency_map_np)), indices):
saliency_map = saliency_map_np[smap_i]
# Each detection metadata entry unpacks as (box, score, label_index).
box, score, label_index = metadata[Task.DETECTION][target_index]
x1, y1, x2, y2 = box
cv2.rectangle(saliency_map, (int(x1), int(y1)), (int(x2), int(y2)), color=(255, 0, 0), thickness=2)
box_label = f"{label_index}|{score:.2f}"
# Anchor the text 5 px above the (x1, y1) corner of the box.
box_label_loc = int(x1), int(y1 - 5)
cv2.putText(
saliency_map,
box_label,
org=box_label_loc,
fontFace=1,
fontScale=1,
color=(255, 0, 0),
thickness=2,
)
return saliency_map_np
def _put_classification_info(
    saliency_map_np: np.ndarray,
    indices: List[int | str],
    label_names: List[str] | None,
    predictions: Dict[int, Prediction],
) -> None:
    """Write the target label, and its score when known, onto each saliency map.

    The text is drawn in place at pixel offset (3, 17) of every map.

    Args:
        saliency_map_np: Saliency maps, one per entry of ``indices``.
        indices: Target index of each saliency map, in the same order.
        label_names: Optional names shown instead of the raw target index.
        predictions: Optional per-target predictions that provide the score.
    """
    corner_location = 3, 17
    for saliency_map, target_index in zip(saliency_map_np, indices):
        label = label_names[target_index] if label_names else str(target_index)
        prediction = predictions.get(target_index) if predictions else None
        # Compare with None rather than truthiness: a score of exactly 0.0 is
        # a valid value and must still be displayed.
        if prediction is not None and prediction.score is not None:
            label = f"{label}|{prediction.score:.2f}"

        cv2.putText(
            saliency_map,
            label,
            org=corner_location,
            fontFace=1,
            fontScale=1.3,
            color=(255, 0, 0),
            thickness=2,
        )

@staticmethod
def _put_detection_info(
    saliency_map_np: np.ndarray,
    indices: List[int | str],
    label_names: List[str] | None,
    predictions: Dict[int, Prediction],
) -> None:
    """Draw the predicted box and its ``label|score`` caption onto each saliency map.

    Drawing happens in place. A target without a prediction, or whose
    prediction has no bounding box, is left untouched instead of raising.

    Args:
        saliency_map_np: Saliency maps, one per entry of ``indices``.
        indices: Target index of each saliency map, in the same order.
        label_names: Optional names shown instead of the predicted label index.
        predictions: Per-target predictions holding label, score and bounding box.
    """
    for saliency_map, target_index in zip(saliency_map_np, indices):
        prediction = predictions.get(target_index) if predictions else None
        if prediction is None or prediction.bounding_box is None:
            # Nothing can be drawn for a target without a predicted box.
            continue

        x1, y1, x2, y2 = prediction.bounding_box
        cv2.rectangle(saliency_map, (int(x1), int(y1)), (int(x2), int(y2)), color=(255, 0, 0), thickness=2)

        label = prediction.label
        if label_names and label is not None:
            label = label_names[label]
        # The score is optional on Prediction; formatting None with ".2f" would raise.
        if prediction.score is None:
            label_score = f"{label}"
        else:
            label_score = f"{label}|{prediction.score:.2f}"
        # Anchor the caption 5 px above the (x1, y1) corner of the box.
        box_location = int(x1), int(y1 - 5)
        cv2.putText(
            saliency_map,
            label_score,
            org=box_location,
            fontFace=1,
            fontScale=1.3,
            color=(255, 0, 0),
            thickness=2,
        )

@staticmethod
def _apply_scaling(explanation: Explanation, saliency_map_np: np.ndarray) -> np.ndarray:
Expand Down
12 changes: 10 additions & 2 deletions openvino_xai/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import collections
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Mapping
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Tuple

import numpy as np
import openvino as ov
Expand All @@ -25,7 +26,7 @@ def __init__(
self._model_compiled = None
self.preprocess_fn = preprocess_fn
self._device_name = device_name
self.metadata: Dict[Task, Any] = collections.defaultdict(dict)
self.predictions = {}

@property
def model_compiled(self) -> ov.CompiledModel | None:
Expand All @@ -50,3 +51,10 @@ def generate_saliency_map(self, data: np.ndarray) -> Dict[int, np.ndarray] | np.
def load_model(self) -> None:
core = ov.Core()
self._model_compiled = core.compile_model(model=self._model, device_name=self._device_name)


@dataclass
class Prediction:
label: int | None = None
score: float | None = None
bounding_box: Tuple | None = None
13 changes: 10 additions & 3 deletions openvino_xai/methods/black_box/aise/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
scaling,
sigmoid,
)
from openvino_xai.methods.base import Prediction
from openvino_xai.methods.black_box.aise.base import AISEBase, GaussianPerturbationMask
from openvino_xai.methods.black_box.base import Preset
from openvino_xai.methods.black_box.utils import check_classification_output
Expand Down Expand Up @@ -91,11 +92,12 @@ def generate_saliency_map( # type: ignore
"""
self.data_preprocessed = self.preprocess_fn(data)

logits = self.get_logits(self.data_preprocessed)
if target_indices is None:
num_classes = self.get_num_classes(self.data_preprocessed)
if num_classes > 10:
logger.info(f"num_classes = {num_classes}, which might take significant time to process.")
num_classes = logits.shape[1]
target_indices = list(range(num_classes))
if len(target_indices) > 10:
logger.info(f"{len(target_indices)} targets to process, which might take significant time.")

self.num_iterations_per_kernel, self.kernel_widths = self._preset_parameters(
preset,
Expand All @@ -110,6 +112,7 @@ def generate_saliency_map( # type: ignore
self._mask_generator = GaussianPerturbationMask(self.input_size)

saliency_maps = {}
self.predictions = {}
for target in target_indices:
self.kernel_params_hist = collections.defaultdict(list)
self.pred_score_hist = collections.defaultdict(list)
Expand All @@ -119,6 +122,10 @@ def generate_saliency_map( # type: ignore
if scale_output:
saliency_map_per_target = scaling(saliency_map_per_target)
saliency_maps[target] = saliency_map_per_target
self.predictions[target] = Prediction(
label=target,
score=logits[0][target],
)
return saliency_maps

@staticmethod
Expand Down
14 changes: 10 additions & 4 deletions openvino_xai/methods/black_box/aise/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
logger,
scaling,
)
from openvino_xai.methods.base import Prediction
from openvino_xai.methods.black_box.aise.base import AISEBase, GaussianPerturbationMask
from openvino_xai.methods.black_box.base import Preset
from openvino_xai.methods.black_box.utils import check_detection_output
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(
prepare_model=prepare_model,
)
self.deletion = False
self.predictions = {}

def generate_saliency_map( # type: ignore
self,
Expand Down Expand Up @@ -120,7 +122,7 @@ def generate_saliency_map( # type: ignore
self._mask_generator = GaussianPerturbationMask(self.input_size)

saliency_maps = {}
self.metadata: Dict[Task, Any] = collections.defaultdict(dict)
self.predictions = {}
for target in target_indices:
self.target_box = boxes[target]
self.target_label = labels[target]
Expand All @@ -137,7 +139,7 @@ def generate_saliency_map( # type: ignore
saliency_map_per_target = scaling(saliency_map_per_target)
saliency_maps[target] = saliency_map_per_target

self._update_metadata(boxes, scores, labels, target, original_size)
self._update_predictions(boxes, scores, labels, target, original_size)
return saliency_maps

@staticmethod
Expand Down Expand Up @@ -205,7 +207,7 @@ def _iou(box1: np.ndarray | List[float], box2: np.ndarray | List[float]) -> floa
area2 = np.prod(box2[2:] - box2[:2])
return intersection / (area1 + area2 - intersection)

def _update_metadata(
def _update_predictions(
self,
boxes: np.ndarray | List,
scores: np.ndarray | List[float],
Expand All @@ -218,4 +220,8 @@ def _update_metadata(
height_scale = original_size[0] / self.input_size[0]
x1, x2 = x1 * width_scale, x2 * width_scale
y1, y2 = y1 * height_scale, y2 * height_scale
self.metadata[Task.DETECTION][target] = [x1, y1, x2, y2], scores[target], labels[target]
self.predictions[target] = Prediction(
label=labels[target],
score=scores[target],
bounding_box=[x1, y1, x2, y2],
)
7 changes: 4 additions & 3 deletions openvino_xai/methods/black_box/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from enum import Enum

import numpy as np
import openvino.runtime as ov

from openvino_xai.methods.base import MethodBase
Expand All @@ -18,12 +19,12 @@ def prepare_model(self, load_model: bool = True) -> ov.Model:
self.load_model()
return self._model

def get_num_classes(self, data_preprocessed):
"""Estimates number of classes for the classification model. Expects batch dimension."""
def get_logits(self, data_preprocessed: np.ndarray) -> np.ndarray:
    """Get logits of the classification model. Expects a batch dimension.

    Args:
        data_preprocessed: Already preprocessed input; it is forwarded to the
            model with ``preprocess=False``.

    Returns:
        The post-processed model output, after it passed
        ``check_classification_output``.
    """
    forward_output = self.model_forward(data_preprocessed, preprocess=False)
    logits = self.postprocess_fn(forward_output)
    check_classification_output(logits)
    # Return the logits themselves, not their class count: callers read both
    # ``logits.shape[1]`` and the per-target score ``logits[0][target]``.
    return logits


class Preset(Enum):
Expand Down
9 changes: 8 additions & 1 deletion openvino_xai/methods/black_box/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tqdm import tqdm

from openvino_xai.common.utils import IdentityPreprocessFN, is_bhwc_layout, scaling
from openvino_xai.methods.base import Prediction
from openvino_xai.methods.black_box.base import BlackBoxXAIMethod, Preset
from openvino_xai.methods.black_box.utils import check_classification_output

Expand Down Expand Up @@ -132,9 +133,10 @@ def _run_synchronous_explanation(
) -> np.ndarray:
input_size = data_preprocessed.shape[1:3] if is_bhwc_layout(data_preprocessed) else data_preprocessed.shape[2:4]

num_classes = self.get_num_classes(data_preprocessed)
logits = self.get_logits(data_preprocessed)

if target_classes is None:
num_classes = logits.shape[1]
num_targets = num_classes
else:
num_targets = len(target_classes)
Expand All @@ -159,6 +161,11 @@ def _run_synchronous_explanation(

if target_classes is not None:
saliency_maps = self._reformat_as_dict(saliency_maps, target_classes)
for target in target_classes:
self.predictions[target] = Prediction(
label=target,
score=logits[0][target],
)
return saliency_maps

@staticmethod
Expand Down
Loading