Insertion Deletion AUC metric #56

Merged 9 commits on Aug 26, 2024
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -21,7 +21,8 @@
* Upgrade OpenVINO to 2024.3.0 by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/52
* Add saliency map visualization with explanation.plot() by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/53
* Enable flexible naming for saved saliency maps and include confidence scores by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/51
-* Add [Pointing Game](https://arxiv.org/abs/1608.00507) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/54
+* Add [Pointing Game](https://link.springer.com/article/10.1007/s11263-017-1059-x) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/54
+* Add [Insertion Deletion AUC](https://arxiv.org/abs/1806.07421) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/56

### Known Issues

27 changes: 17 additions & 10 deletions openvino_xai/explainer/utils.py
@@ -119,22 +119,29 @@
)


-def postprocess_fn(x: Mapping, logit_name="logits") -> np.ndarray:
-    """Postprocess function."""
-    return x.get(logit_name, x[0])  # Models from OVC has no output names at times


-def get_postprocess_fn(logit_name="logits") -> Callable[[], np.ndarray]:
-    """Returns partially initialized postprocess_fn."""
-    return partial(postprocess_fn, logit_name=logit_name)


class ActivationType(Enum):
SIGMOID = "sigmoid"
SOFTMAX = "softmax"
NONE = "none"


+def postprocess_fn(x: Mapping, logit_name="logits", activation: ActivationType = ActivationType.NONE) -> np.ndarray:
+    """Extracts logits from the model output and optionally applies an activation."""
+    x = x.get(logit_name, x[0])  # Models from OVC sometimes have no output names
+    if activation == ActivationType.SOFTMAX:
+        return softmax(x)
+    if activation == ActivationType.SIGMOID:
+        return sigmoid(x)
+    return x


+def get_postprocess_fn(
+    logit_name="logits", activation: ActivationType = ActivationType.NONE
+) -> Callable[..., np.ndarray]:
+    """Returns a partially initialized postprocess_fn."""
+    return partial(postprocess_fn, logit_name=logit_name, activation=activation)


def get_score(x: np.ndarray, index: int, activation: ActivationType = ActivationType.NONE):
"""Returns activated score at index."""
if activation == ActivationType.SOFTMAX:
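For context, a minimal sketch of how the new `activation` argument can be wired up. The output name and the logits below are made up for illustration, and the exact normalization depends on the `softmax` helper in `openvino_xai.explainer.utils`:

```python
import numpy as np

from openvino_xai.explainer.utils import ActivationType, get_postprocess_fn

# Hypothetical model output: a mapping from output name to raw logits.
raw_output = {"logits": np.array([[2.0, 0.5, -1.0]])}

# Resolve the "logits" key and apply softmax in one callable.
postprocess_fn = get_postprocess_fn(activation=ActivationType.SOFTMAX)
probs = postprocess_fn(raw_output)
print(probs.sum())  # ~1.0, assuming softmax normalizes over the last axis
```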
4 changes: 2 additions & 2 deletions openvino_xai/methods/base.py
@@ -5,7 +5,7 @@
from typing import Callable, Dict, Mapping

import numpy as np
-import openvino.runtime as ov
+import openvino as ov

from openvino_xai.common.utils import IdentityPreprocessFN

@@ -25,7 +25,7 @@ def __init__(
self._device_name = device_name

@property
-    def model_compiled(self) -> ov.ie_api.CompiledModel | None:
+    def model_compiled(self) -> ov.CompiledModel | None:
return self._model_compiled

@abstractmethod
36 changes: 36 additions & 0 deletions openvino_xai/metrics/base.py
@@ -0,0 +1,36 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List

import numpy as np
import openvino as ov

from openvino_xai.common.utils import IdentityPreprocessFN
from openvino_xai.explainer.explanation import Explanation


class BaseMetric(ABC):
"""Base class for XAI quality metric."""

def __init__(
self,
model_compiled: ov.CompiledModel = None,
preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
postprocess_fn: Callable[[np.ndarray], np.ndarray] = None,
):
# Pass model_predict to class initialization directly?
self.model_compiled = model_compiled
self.preprocess_fn = preprocess_fn
self.postprocess_fn = postprocess_fn

def model_predict(self, input: np.ndarray) -> np.ndarray:
logits = self.model_compiled([self.preprocess_fn(input)])
logits = self.postprocess_fn(logits)[0]
return logits

@abstractmethod
def __call__(self, saliency_map, *args: Any, **kwargs: Any) -> Dict[str, float]:
"""Calculate the metric for the single saliency map"""

@abstractmethod
def evaluate(self, explanations: List[Explanation], *args: Any, **kwargs: Any) -> Dict[str, float]:
"""Evaluate the quality of saliency maps over the list of images"""
109 changes: 109 additions & 0 deletions openvino_xai/metrics/insertion_deletion_auc.py
@@ -0,0 +1,109 @@
from typing import Any, Dict, List, Tuple

import numpy as np

from openvino_xai.explainer.explanation import Explanation, Layout
from openvino_xai.metrics.base import BaseMetric


def AUC(arr: np.ndarray) -> float:
    """
    Returns the normalized area under the curve of the array,
    computed with the trapezoidal rule on a unit-length x-axis.
    """
    return np.abs((arr.sum() - arr[0] / 2 - arr[-1] / 2) / (arr.shape[0] - 1))


class InsertionDeletionAUC(BaseMetric):
"""
Implementation of the Insertion and Deletion AUC by Petsiuk et al. 2018.

References:
        Petsiuk, Vitali, Abir Das, and Kate Saenko. "RISE: Randomized Input Sampling
        for Explanation of Black-box Models." arXiv preprint arXiv:1806.07421 (2018).
"""

@staticmethod
def step_image_insertion_deletion(
num_pixels: int, sorted_indices: Tuple[np.ndarray, np.ndarray], input_image: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""
Return insertion/deletion image based on number of pixels to add/delete on this step.
"""
# Values to start
image_insertion = np.full_like(input_image, 0)
image_deletion = input_image.copy()

x_indices = sorted_indices[0][:num_pixels]
y_indices = sorted_indices[1][:num_pixels]

        # Reveal the original pixel values at the most important locations
        image_insertion[x_indices, y_indices] = input_image[x_indices, y_indices]
        # Zero out the pixel values at the most important locations
image_deletion[x_indices, y_indices] = 0
return image_insertion, image_deletion

def __call__(
self, saliency_map: np.ndarray, class_idx: int, input_image: np.ndarray, steps: int = 100, **kwargs: Any
) -> Dict[str, float]:
"""
Calculate the Insertion and Deletion AUC metrics for one saliency map for one class.

Parameters:
:param saliency_map: Importance scores for each pixel (H, W).
:type saliency_map: np.ndarray
:param class_idx: The class of saliency map to evaluate.
:type class_idx: int
:param input_image: The input image to the model (H, W, C).
:type input_image: np.ndarray
:param steps: Number of steps for inserting pixels.
:type steps: int

Returns:
:return: A dictionary containing the AUC scores for insertion and deletion scores.
:rtype: Dict[str, float]
"""
# Sort pixels by descending importance to find the most important pixels
sorted_indices = np.argsort(-saliency_map.flatten())
sorted_indices = np.unravel_index(sorted_indices, saliency_map.shape)

insertion_scores, deletion_scores = [], []
for i in range(steps + 1):
num_pixels = int(i * len(sorted_indices[0]) / steps)
step_image_insertion, step_image_deletion = self.step_image_insertion_deletion(
num_pixels, sorted_indices, input_image
)
# Predict on masked image
insertion_scores.append(self.model_predict(step_image_insertion)[class_idx])
deletion_scores.append(self.model_predict(step_image_deletion)[class_idx])
insertion = AUC(np.array(insertion_scores))
deletion = AUC(np.array(deletion_scores))
return {"insertion": insertion, "deletion": deletion}

def evaluate(
self, explanations: List[Explanation], input_images: List[np.ndarray], steps: int, **kwargs: Any
) -> Dict[str, float]:
"""
Evaluate the insertion and deletion AUC over the list of images and its saliency maps.

:param explanations: List of explanation objects containing saliency maps.
:type explanations: List[Explanation]
:param input_images: List of input images as numpy arrays.
:type input_images: List[np.ndarray]
:param steps: Number of steps for the insertion and deletion process.
:type steps: int

:return: A Dict containing the mean insertion AUC, mean deletion AUC, and their difference (delta) as values.
:rtype: float
"""
for explanation in explanations:
assert explanation.layout in [Layout.MULTIPLE_MAPS_PER_IMAGE_GRAY, Layout.MULTIPLE_MAPS_PER_IMAGE_COLOR]

results = []
for input_image, explanation in zip(input_images, explanations):
for class_idx, saliency_map in explanation.saliency_map.items():
metric_dict = self(saliency_map, int(class_idx), input_image, steps)
results.append([metric_dict["insertion"], metric_dict["deletion"]])

insertion, deletion = np.mean(np.array(results), axis=0)
delta = insertion - deletion
return {"insertion": insertion, "deletion": deletion, "delta": delta}
40 changes: 24 additions & 16 deletions openvino_xai/metrics/pointing_game.py
@@ -1,15 +1,16 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Tuple

import numpy as np

from openvino_xai.common.utils import logger
from openvino_xai.explainer.explanation import Explanation
+from openvino_xai.metrics.base import BaseMetric


-class PointingGame:
+class PointingGame(BaseMetric):
"""
Implementation of the Pointing Game by Zhang et al., 2018.

@@ -29,18 +30,21 @@ class PointingGame:
"""

@staticmethod
-    def pointing_game(saliency_map: np.ndarray, image_gt_bboxes: List[Tuple[int, int, int, int]]) -> bool:
+    def __call__(saliency_map: np.ndarray, image_gt_bboxes: List[Tuple[int, int, int, int]]) -> Dict[str, float]:
        """
-        Implements the Pointing Game metric using a saliency map and bounding boxes of the same image and class.
-        Returns a boolean indicating if any of the most salient points fall within the ground truth bounding boxes.
+        Calculate the Pointing Game metric for one saliency map of one class.
+
+        This implementation uses a saliency map and the bounding boxes of the same image and class.
+        Returns a dictionary with the result: 1.0 if any of the most salient points falls within
+        a ground truth bounding box, 0.0 otherwise.

:param saliency_map: A 2D numpy array representing the saliency map for the image.
:type saliency_map: np.ndarray
:param image_gt_bboxes: A list of tuples (x, y, w, h) representing the bounding boxes of the ground truth objects.
:type image_gt_bboxes: List[Tuple[int, int, int, int]]

-        :return: True if any of the most salient points fall within any of the ground truth bounding boxes, False otherwise.
-        :rtype: bool
+        :return: A dictionary with the result of the Pointing Game metric.
+        :rtype: Dict[str, float]
"""
# TODO: Optimize calculation by generating a mask from annotation and finding the intersection
# Find the most salient points in the saliency map
@@ -51,12 +55,15 @@ def pointing_game(saliency_map: np.ndarray, image_gt_bboxes: List[Tuple[int, int
for max_point_y, max_point_x in max_indices:
# Check if this point is within the ground truth bounding box
if x <= max_point_x <= x + w and y <= max_point_y <= y + h:
-                return True
-        return False
+                return {"pointing_game": 1.0}
+        return {"pointing_game": 0.0}

def evaluate(
-        self, explanations: List[Explanation], gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]]
-    ) -> float:
+        self,
+        explanations: List[Explanation],
+        gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]],
+        **kwargs: Any,
+    ) -> Dict[str, float]:
"""
        Evaluates the Pointing Game metric over a set of images. Skips a saliency map if the ground truth bounding boxes for its class are absent.

@@ -65,15 +72,15 @@ def evaluate(
:param gt_bboxes: A list of dictionaries {label_name: lists of bounding boxes} for each image.
:type gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]]

-        :return: Pointing game score over a list of images
-        :rtype: float
+        :return: Dict with "pointing_game" as the key and the score over the list of images as the value.
+        :rtype: Dict[str, float]
"""

assert len(explanations) == len(
gt_bboxes
), "Number of explanations and ground truth bounding boxes must match and equal to number of images."

-        hits = 0
+        hits = 0.0
num_sal_maps = 0
for explanation, image_gt_bboxes in zip(explanations, gt_bboxes):
label_names = explanation.label_names
@@ -90,7 +97,8 @@ def evaluate(
continue

class_gt_bboxes = image_gt_bboxes[label_name]
-                hits += self.pointing_game(class_sal_map, class_gt_bboxes)
+                hits += self(class_sal_map, class_gt_bboxes)["pointing_game"]
num_sal_maps += 1

-        return hits / num_sal_maps if num_sal_maps > 0 else 0.0
+        score = hits / num_sal_maps if num_sal_maps > 0 else 0.0
+        return {"pointing_game": score}
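A short sanity-check sketch of the new callable form. The map size and the box are invented, and the expected output assumes the metric's internal max-point search picks the single hottest pixel:

```python
import numpy as np

from openvino_xai.metrics.pointing_game import PointingGame

pointing_game = PointingGame()

# Saliency map whose single hottest pixel lies inside the (x, y, w, h) box below.
saliency_map = np.zeros((224, 224), dtype=np.float32)
saliency_map[60, 50] = 1.0  # row (y) = 60, column (x) = 50

print(pointing_game(saliency_map, [(40, 40, 50, 50)]))  # {"pointing_game": 1.0}
```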