Add ADCC metric (#57)
* Draft pointing game implementation

* Add insertion deletion auc

* Add ADCC

* Update auc

* Introduce BaseMetric as a parent class

* Delete ADCC

* Remove adcc tests

* Fixes from comments

* Add ADCC

* Remove scaling logic

* Add extra unit test

* Update threshold value

* Update Changelog
GalyaZalesskaya authored Aug 28, 2024
1 parent 4e39758 commit 4ce1903
Showing 8 changed files with 252 additions and 13 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -9,6 +9,7 @@
* Upgrade OpenVINO to 2024.3.0
* Add saliency map visualization with explanation.plot()
* Enable flexible naming for saved saliency maps and include confidence scores
* Add Pointing Game, Insertion-Deletion AUC and ADCC quality metrics for saliency maps

### What's Changed

@@ -22,7 +23,8 @@
* Add saliency map visualization with explanation.plot() by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/53
* Enable flexible naming for saved saliency maps and include confidence scores by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/51
* Add [Pointing Game](https://link.springer.com/article/10.1007/s11263-017-1059-x) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/54
* Add [Insertion Deletion AUC](https://arxiv.org/abs/1806.07421) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/56
* Add [Insertion-Deletion AUC](https://arxiv.org/abs/1806.07421) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/56
* Add [ADCC](https://arxiv.org/abs/2104.10252) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/57

### Known Issues

6 changes: 6 additions & 0 deletions openvino_xai/metrics/__init__.py
@@ -3,3 +3,9 @@
"""
Metrics in OpenVINO-XAI to check the quality of saliency maps.
"""

from openvino_xai.metrics.adcc import ADCC
from openvino_xai.metrics.insertion_deletion_auc import InsertionDeletionAUC
from openvino_xai.metrics.pointing_game import PointingGame

__all__ = ["ADCC", "InsertionDeletionAUC", "PointingGame"]
132 changes: 132 additions & 0 deletions openvino_xai/metrics/adcc.py
@@ -0,0 +1,132 @@
from typing import Any, Dict, List

import numpy as np
from scipy import stats as STS

from openvino_xai import Task
from openvino_xai.explainer.explainer import Explainer, ExplainMode
from openvino_xai.explainer.explanation import Explanation
from openvino_xai.metrics.base import BaseMetric


class ADCC(BaseMetric):
"""
Implementation of the Average Drop-Coherence-Complexity (ADCC) metric by Poppi, Samuele, et al., 2021.
References:
Poppi, Samuele, et al. "Revisiting the evaluation of class activation mapping for explainability:
A novel metric and experimental analysis." Proceedings of the IEEE/CVF Conference on
Computer Vision and Pattern Recognition. 2021.
"""

def __init__(self, model, preprocess_fn, postprocess_fn, explainer=None, device_name="CPU"):
super().__init__(
model=model, preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn, device_name=device_name
)
if explainer is None:
self.explainer = Explainer(
model=model,
task=Task.CLASSIFICATION,
preprocess_fn=self.preprocess_fn,
explain_mode=ExplainMode.WHITEBOX,
)
else:
self.explainer = explainer

def average_drop(
self, saliency_map: np.ndarray, class_idx: int, image: np.ndarray, model_output: np.ndarray
) -> float:
"""
Measures the average percentage drop in confidence for the target class when the model sees only the
explanation map (image masked with saliency map), instead of the full image.
Lower is better.
"""
confidence_on_input = np.max(model_output)

masked_image = (image * saliency_map[:, :, None]).astype(np.uint8)
prediction_on_saliency_map = self.model_predict(masked_image)
confidence_on_saliency_map = prediction_on_saliency_map[class_idx]

return max(0.0, confidence_on_input - confidence_on_saliency_map) / confidence_on_input

def coherency(self, saliency_map: np.ndarray, class_idx: int, image: np.ndarray) -> float:
"""
Measures the coherency of the saliency map. The explanation map (image masked with the saliency map) should contain all the relevant features that explain the prediction and should discard useless features in a coherent way.
The saliency map of the original image and the saliency map of the explanation map should therefore be similar.
Higher is better.
"""

masked_image = image * saliency_map[:, :, None]
saliency_map_mapped_image = self.explainer(masked_image, targets=[class_idx], colormap=False, scaling=False)
saliency_map_mapped_image = saliency_map_mapped_image.saliency_map[class_idx]

A, B = saliency_map, saliency_map_mapped_image
# Pearson correlation coefficient
Asq, Bsq = A.flatten(), B.flatten()
y, _ = STS.pearsonr(Asq, Bsq)
y = (y + 1) / 2

return y

@staticmethod
def complexity(saliency_map: np.ndarray) -> float:
"""
Measures the complexity of the saliency map: the fewer highlighted pixels, the lower the complexity.
Defined as the L1 norm of the saliency map, normalized by the number of pixels.
Lower is better.
"""
return abs(saliency_map).sum() / (saliency_map.shape[-1] * saliency_map.shape[-2])

def __call__(self, saliency_map: np.ndarray, class_idx: int, input_image: np.ndarray) -> Dict[str, float]:
"""
Calculate the ADCC metric for a given saliency map and class index.
Higher is better.
Parameters:
:param saliency_map: Saliency map for class_idx class (H, W).
:type saliency_map: np.ndarray
:param class_idx: The class index of saliency map.
:type class_idx: int
:param input_image: The input image to the model (H, W, C).
:type input_image: np.ndarray
Returns:
:return: A dictionary containing the ADCC, coherency, complexity, and average drop metrics.
:rtype: Dict[str, float]
"""
if not (0 <= np.min(saliency_map) and np.max(saliency_map) <= 1):
# Scale saliency map to [0, 1]
saliency_map = saliency_map / 255

model_output = self.model_predict(input_image)

avgdrop = self.average_drop(saliency_map, class_idx, input_image, model_output)
coh = self.coherency(saliency_map, class_idx, input_image)
com = self.complexity(saliency_map)

adcc = 3 / (1 / coh + 1 / (1 - com) + 1 / (1 - avgdrop))
return {"adcc": adcc, "coherency": coh, "complexity": com, "average_drop": avgdrop}

def evaluate(
self, explanations: List[Explanation], input_images: List[np.ndarray], **kwargs: Any
) -> Dict[str, float]:
"""
Evaluate the ADCC metric over a list of explanations and input images.
Parameters:
:param explanations: A list of explanations for each image.
:type explanations: List[Explanation]
:param input_images: A list of input images.
:type input_images: List[np.ndarray]
Returns:
:return: A dictionary containing the average ADCC score.
:rtype: Dict[str, float]
"""
results = []
for input_image, explanation in zip(input_images, explanations):
for class_idx, saliency_map in explanation.saliency_map.items():
metric_dict = self(saliency_map, int(class_idx), input_image)
results.append(metric_dict["adcc"])
adcc = np.mean(np.array(results), axis=0)
return {"adcc": adcc}
6 changes: 4 additions & 2 deletions openvino_xai/metrics/base.py
@@ -13,12 +13,14 @@ class BaseMetric(ABC):

def __init__(
self,
model_compiled: ov.CompiledModel = None,
model: ov.Model = None,
preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
postprocess_fn: Callable[[np.ndarray], np.ndarray] = None,
device_name: str = "CPU",
):
# Pass model_predict to class initialization directly?
self.model_compiled = model_compiled
self.model = model
self.model_compiled = ov.Core().compile_model(model=model, device_name=device_name)
self.preprocess_fn = preprocess_fn
self.postprocess_fn = postprocess_fn

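
The practical effect of this change shows up in the test updates below: callers no longer compile the model themselves but pass an `ov.Model`, and the metric compiles it on the requested device. A before/after sketch, with a placeholder model path and the same pre/postprocessing helpers as in the tests:

```python
import openvino as ov

from openvino_xai.explainer.utils import ActivationType, get_postprocess_fn, get_preprocess_fn
from openvino_xai.metrics import InsertionDeletionAUC

model = ov.Core().read_model("model.xml")  # placeholder path
preprocess_fn = get_preprocess_fn(change_channel_order=True, input_size=(224, 224), hwc_to_chw=True)
postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)

# Before this commit: the caller compiled the model and passed an ov.CompiledModel.
# compiled_model = ov.Core().compile_model(model=model, device_name="CPU")
# auc = InsertionDeletionAUC(compiled_model, preprocess_fn, postprocess_fn)

# After this commit: pass the ov.Model; BaseMetric compiles it internally ("CPU" by default).
auc = InsertionDeletionAUC(model, preprocess_fn, postprocess_fn)
```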
3 changes: 3 additions & 0 deletions openvino_xai/metrics/pointing_game.py
@@ -29,6 +29,9 @@ class PointingGame(BaseMetric):
(2018) 126:1084-1102.
"""

def __init__(self):
pass

@staticmethod
def __call__(saliency_map: np.ndarray, image_gt_bboxes: List[Tuple[int, int, int, int]]) -> Dict[str, float]:
"""
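
The added no-op `__init__` makes `PointingGame` instantiable without a model, preprocessing, or device, since it only checks the saliency peak against ground-truth boxes. A small sketch, not from the commit; the box tuple layout is assumed to be COCO-style `(x, y, width, height)` as in the regression-test annotations, so verify it against the full `__call__` implementation, which is collapsed here:

```python
import numpy as np

from openvino_xai.metrics import PointingGame

pointing_game = PointingGame()  # no model or preprocessing needed

saliency_map = np.zeros((224, 224), dtype=np.float32)
saliency_map[105, 110] = 1.0  # saliency peak

# Assumed COCO-style (x, y, width, height) box that contains the peak above.
gt_bboxes = [(100, 100, 30, 30)]

print(pointing_game(saliency_map, gt_bboxes))
```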
28 changes: 20 additions & 8 deletions tests/regression/test_regression.py
@@ -16,8 +16,7 @@
get_postprocess_fn,
get_preprocess_fn,
)
from openvino_xai.metrics.insertion_deletion_auc import InsertionDeletionAUC
from openvino_xai.metrics.pointing_game import PointingGame
from openvino_xai.metrics import ADCC, InsertionDeletionAUC, PointingGame
from tests.unit.explanation.test_explanation_utils import VOC_NAMES

MODEL_NAME = "mlc_mobilenetv3_large_voc"
@@ -57,7 +56,6 @@ def load_gt_bboxes(json_coco_path: str) -> List[Dict[str, List[Tuple[int, int, i
class TestDummyRegression:
image = cv2.imread(IMAGE_PATH)
gt_bboxes = load_gt_bboxes(COCO_ANN_PATH)
pointing_game = PointingGame()

preprocess_fn = get_preprocess_fn(
change_channel_order=True,
@@ -73,8 +71,6 @@ def setup(self, fxt_data_root):
retrieve_otx_model(data_dir, MODEL_NAME)
model_path = data_dir / "otx_models" / (MODEL_NAME + ".xml")
model = ov.Core().read_model(model_path)
compiled_model = ov.Core().compile_model(model, "CPU")
self.auc = InsertionDeletionAUC(compiled_model, self.preprocess_fn, self.postprocess_fn)

self.explainer = Explainer(
model=model,
@@ -83,29 +79,42 @@ def setup(self, fxt_data_root):
explain_mode=ExplainMode.WHITEBOX,
)

self.pointing_game = PointingGame()
self.auc = InsertionDeletionAUC(model, self.preprocess_fn, self.postprocess_fn)
self.adcc = ADCC(model, self.preprocess_fn, self.postprocess_fn, self.explainer)

def test_explainer_image(self):
explanation = self.explainer(self.image, targets=["person"], label_names=VOC_NAMES, colormap=False)
assert len(explanation.saliency_map) == 1

pointing_game_score = self.pointing_game.evaluate([explanation], self.gt_bboxes)["pointing_game"]
assert pointing_game_score == 1.0

explanation = self.explainer(self.image, targets=["person"], label_names=VOC_NAMES, colormap=False)
assert len(explanation.saliency_map) == 1
auc_score = self.auc.evaluate([explanation], [self.image], steps=10).values()
insertion_auc_score, deletion_auc_score, delta_auc_score = auc_score
assert insertion_auc_score >= 0.9
assert deletion_auc_score >= 0.2
assert delta_auc_score >= 0.7

# Two classes for saliency maps
adcc_score = self.adcc.evaluate([explanation], [self.image])["adcc"]
assert adcc_score > 0.9

def test_explainer_image_2_classes(self):
explanation = self.explainer(self.image, targets=["person", "cat"], label_names=VOC_NAMES, colormap=False)
assert len(explanation.saliency_map) == 2

pointing_game_score = self.pointing_game.evaluate([explanation], self.gt_bboxes)["pointing_game"]
assert pointing_game_score == 1.0

auc_score = self.auc.evaluate([explanation], [self.image], steps=10).values()
insertion_auc_score, deletion_auc_score, delta_auc_score = auc_score
assert insertion_auc_score >= 0.5
assert deletion_auc_score >= 0.1
assert delta_auc_score >= 0.35

adcc_score = self.adcc.evaluate([explanation], [self.image])["adcc"]
assert adcc_score > 0.5

def test_explainer_images(self):
images = [self.image, self.image]
explanations = []
@@ -122,3 +131,6 @@ def test_explainer_images(self):
assert insertion_auc_score >= 0.9
assert deletion_auc_score >= 0.2
assert delta_auc_score >= 0.7

adcc_score = self.adcc.evaluate(explanations, images)["adcc"]
assert adcc_score > 0.9
83 changes: 83 additions & 0 deletions tests/unit/metrics/test_adcc.py
@@ -0,0 +1,83 @@
import json
from typing import Callable, List, Mapping

import cv2
import numpy as np
import openvino as ov
import pytest

from openvino_xai import Task
from openvino_xai.common.utils import retrieve_otx_model
from openvino_xai.explainer.explainer import Explainer, ExplainMode
from openvino_xai.explainer.explanation import Explanation
from openvino_xai.explainer.utils import (
ActivationType,
get_postprocess_fn,
get_preprocess_fn,
sigmoid,
)
from openvino_xai.methods.black_box.base import Preset
from openvino_xai.metrics.adcc import ADCC
from openvino_xai.metrics.insertion_deletion_auc import InsertionDeletionAUC
from openvino_xai.metrics.pointing_game import PointingGame
from tests.unit.explanation.test_explanation_utils import VOC_NAMES

MODEL_NAME = "mlc_mobilenetv3_large_voc"


class TestADCC:
image = cv2.imread("tests/assets/cheetah_person.jpg")
preprocess_fn = get_preprocess_fn(
change_channel_order=True,
input_size=(224, 224),
hwc_to_chw=True,
)
postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)

@pytest.fixture(autouse=True)
def setup(self, fxt_data_root):
self.data_dir = fxt_data_root
retrieve_otx_model(self.data_dir, MODEL_NAME)
model_path = self.data_dir / "otx_models" / (MODEL_NAME + ".xml")
self.model = ov.Core().read_model(model_path)
self.explainer = Explainer(
model=self.model,
task=Task.CLASSIFICATION,
preprocess_fn=self.preprocess_fn,
explain_mode=ExplainMode.WHITEBOX,
)
self.adcc = ADCC(self.model, self.preprocess_fn, self.postprocess_fn, self.explainer)

def test_adcc_init_wo_explainer(self):
adcc_wo_explainer = ADCC(self.model, self.preprocess_fn, self.postprocess_fn)
assert isinstance(adcc_wo_explainer.explainer, Explainer)

def test_adcc(self):
input_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
saliency_map = np.random.rand(224, 224)

complexity_score = self.adcc.complexity(saliency_map)
assert complexity_score >= 0.2

model_output = self.adcc.model_predict(input_image)
class_idx = np.argmax(model_output)

average_drop_score = self.adcc.average_drop(saliency_map, class_idx, input_image, model_output)
assert average_drop_score >= 0.2

coherency_score = self.adcc.coherency(saliency_map, class_idx, input_image)
assert coherency_score >= 0.2

adcc_score = self.adcc(saliency_map, class_idx, input_image)["adcc"]
assert adcc_score >= 0.4

def test_evaluate(self):
input_images = [np.random.rand(224, 224, 3) for _ in range(5)]
explanations = [
Explanation({0: np.random.rand(224, 224), 1: np.random.rand(224, 224)}, targets=[0, 1]) for _ in range(5)
]

adcc_score = self.adcc.evaluate(explanations, input_images)["adcc"]

assert isinstance(adcc_score, float)
assert 0 <= adcc_score <= 1
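
The final assertion holds because each ADCC component is bounded (assuming all components are strictly positive): coherency is a Pearson correlation remapped to [0, 1], while both 1 - complexity and 1 - average drop lie in [0, 1] by construction, so their harmonic mean stays in [0, 1]:

```latex
0 \le \mathrm{Coherency},\; 1-\mathrm{Complexity},\; 1-\mathrm{AverageDrop} \le 1
\;\Longrightarrow\;
0 \le \frac{3}{\frac{1}{\mathrm{Coherency}} + \frac{1}{1-\mathrm{Complexity}} + \frac{1}{1-\mathrm{AverageDrop}}} \le 1
```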
3 changes: 1 addition & 2 deletions tests/unit/metrics/test_auc.py
@@ -34,8 +34,7 @@ def setup(self, fxt_data_root):
model_path = self.data_dir / "otx_models" / (MODEL_NAME + ".xml")
core = ov.Core()
model = core.read_model(model_path)
compiled_model = core.compile_model(model=model, device_name="AUTO")
self.auc = InsertionDeletionAUC(compiled_model, self.preprocess_fn, self.postprocess_fn)
self.auc = InsertionDeletionAUC(model, self.preprocess_fn, self.postprocess_fn)

self.explainer = Explainer(
model=model,
