-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pointing game accuracy metrics & draft regression tests implementation (
#54) * Draft pointing game implementation * Add insertion deletion auc * Add ADCC * Remove adcc and auc metrics * Support milti class ann and several bb per image * Reformat annotation for better readability * Minor format * Remove utils * Fixes from comments * Minor
- Loading branch information
1 parent
ccfadb9
commit 1e41ff2
Showing
8 changed files
with
337 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
""" | ||
Metrics in OpenVINO-XAI to check the quality of saliency maps. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Dict, List, Tuple | ||
|
||
import numpy as np | ||
|
||
from openvino_xai.common.utils import logger | ||
from openvino_xai.explainer.explanation import Explanation | ||
|
||
|
||
class PointingGame: | ||
""" | ||
Implementation of the Pointing Game by Zhang et al., 2018. | ||
Unlike the original approach that uses ground truth bounding masks, this implementation uses ground | ||
truth bounding boxes. The Pointing Game checks whether the most salient point is within the annotated | ||
object. High scores mean that the most salient pixel belongs to an object of the specified class. | ||
References: | ||
1) Reference implementation: | ||
https://github.com/understandable-machine-intelligence-lab/Quantus/ | ||
Hedström, Anna, et al.: | ||
"Quantus: An explainable ai toolkit for responsible evaluation of neural network explanations and beyond." | ||
Journal of Machine Learning Research 24.34 (2023): 1-11. | ||
2) Jianming Zhang et al.: | ||
"Top-Down Neural Attention by Excitation Backprop." International Journal of Computer Vision | ||
(2018) 126:1084-1102. | ||
""" | ||
|
||
@staticmethod | ||
def pointing_game(saliency_map: np.ndarray, image_gt_bboxes: List[Tuple[int, int, int, int]]) -> bool: | ||
""" | ||
Implements the Pointing Game metric using a saliency map and bounding boxes of the same image and class. | ||
Returns a boolean indicating if any of the most salient points fall within the ground truth bounding boxes. | ||
:param saliency_map: A 2D numpy array representing the saliency map for the image. | ||
:type saliency_map: np.ndarray | ||
:param image_gt_bboxes: A list of tuples (x, y, w, h) representing the bounding boxes of the ground truth objects. | ||
:type image_gt_bboxes: List[Tuple[int, int, int, int]] | ||
:return: True if any of the most salient points fall within any of the ground truth bounding boxes, False otherwise. | ||
:rtype: bool | ||
""" | ||
# TODO: Optimize calculation by generating a mask from annotation and finding the intersection | ||
# Find the most salient points in the saliency map | ||
max_indices = np.argwhere(saliency_map == np.max(saliency_map)) | ||
|
||
# If multiple bounding boxes are available for one image | ||
for x, y, w, h in image_gt_bboxes: | ||
for max_point_y, max_point_x in max_indices: | ||
# Check if this point is within the ground truth bounding box | ||
if x <= max_point_x <= x + w and y <= max_point_y <= y + h: | ||
return True | ||
return False | ||
|
||
def evaluate( | ||
self, explanations: List[Explanation], gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]] | ||
) -> float: | ||
""" | ||
Evaluates the Pointing Game metric over a set of images. Skips saliency maps if the gt bboxes for this class are absent. | ||
:param explanations: A list of explanations for each image. | ||
:type explanations: List[Explanation] | ||
:param gt_bboxes: A list of dictionaries {label_name: lists of bounding boxes} for each image. | ||
:type gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]] | ||
:return: Pointing game score over a list of images | ||
:rtype: float | ||
""" | ||
|
||
assert len(explanations) == len( | ||
gt_bboxes | ||
), "Number of explanations and ground truth bounding boxes must match and equal to number of images." | ||
|
||
hits = 0 | ||
num_sal_maps = 0 | ||
for explanation, image_gt_bboxes in zip(explanations, gt_bboxes): | ||
label_names = explanation.label_names | ||
assert label_names is not None, "Label names are required for pointing game evaluation." | ||
|
||
for class_idx, class_sal_map in explanation.saliency_map.items(): | ||
label_name = label_names[int(class_idx)] | ||
|
||
if label_name not in image_gt_bboxes: | ||
logger.info( | ||
f"No ground-truth bbox for {label_name} saliency map. " | ||
f"Skip pointing game evaluation for this saliency map." | ||
) | ||
continue | ||
|
||
class_gt_bboxes = image_gt_bboxes[label_name] | ||
hits += self.pointing_game(class_sal_map, class_gt_bboxes) | ||
num_sal_maps += 1 | ||
|
||
return hits / num_sal_maps if num_sal_maps > 0 else 0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
{ | ||
"licenses": [{"name": "", "id": 0, "url": ""}], | ||
"info": {"contributor": "", "date_created": "", "description": "", "url": "", "version": "", "year": ""}, | ||
"categories": [ | ||
{ | ||
"id": 1, | ||
"name": "person", | ||
"supercategory": "" | ||
}, | ||
{ | ||
"id": 2, | ||
"name": "cheetah", | ||
"supercategory": "" | ||
} | ||
], | ||
"images": [ | ||
{ | ||
"id": 1, | ||
"width": 500, | ||
"height": 354, | ||
"file_name": "cheetah_person.jpg", | ||
"license": 0, | ||
"flickr_url": "", | ||
"coco_url": "", | ||
"date_captured": 0 | ||
} | ||
], | ||
"annotations": [ | ||
{ | ||
"id": 1, | ||
"image_id": 1, | ||
"category_id": 1, | ||
"segmentation": [], | ||
"area": 30560.0, | ||
"bbox": [274.0, 99.0, 160.0, 191.0], | ||
"iscrowd": 0 | ||
}, | ||
{ | ||
"id": 2, | ||
"image_id": 1, | ||
"category_id": 2, | ||
"segmentation": [], | ||
"area": 37281.0, | ||
"bbox": [17.0, 160.0, 289.0, 129.0], | ||
"iscrowd": 0 | ||
}, | ||
{ | ||
"id": 3, | ||
"image_id": 1, | ||
"category_id": 2, | ||
"segmentation": [], | ||
"area": 16786.0, | ||
"bbox": [165.0, 129.0, 109.0, 154.0], | ||
"iscrowd": 0 | ||
}, | ||
{ | ||
"id": 4, | ||
"image_id": 1, | ||
"category_id": 2, | ||
"segmentation": [], | ||
"area": 26316.0, | ||
"bbox": [316.0, 111.0, 153.0, 172.0], | ||
"iscrowd": 0 | ||
} | ||
] | ||
} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import json | ||
from typing import Dict, List, Tuple | ||
|
||
import cv2 | ||
import openvino as ov | ||
import pytest | ||
|
||
from openvino_xai import Task | ||
from openvino_xai.common.utils import retrieve_otx_model | ||
from openvino_xai.explainer.explainer import Explainer, ExplainMode | ||
from openvino_xai.explainer.utils import get_preprocess_fn | ||
from openvino_xai.metrics.pointing_game import PointingGame | ||
from tests.unit.explanation.test_explanation_utils import VOC_NAMES | ||
|
||
MODEL_NAME = "mlc_mobilenetv3_large_voc" | ||
IMAGE_PATH = "tests/assets/cheetah_person.jpg" | ||
COCO_ANN_PATH = "tests/assets/cheetah_person_coco.json" | ||
|
||
|
||
def load_gt_bboxes(json_coco_path: str) -> List[Dict[str, List[Tuple[int, int, int, int]]]]: | ||
""" | ||
Loads ground truth bounding boxes from a COCO format JSON file. | ||
Returns a list of dictionaries, where each dictionary corresponds to an image. | ||
The key is the label name and the value is a list of bounding boxes for certain image. | ||
""" | ||
|
||
with open(json_coco_path, "r") as f: | ||
coco_anns = json.load(f) | ||
|
||
result = {} | ||
category_id_to_name = {category["id"]: category["name"] for category in coco_anns["categories"]} | ||
|
||
for annotation in coco_anns["annotations"]: | ||
image_id = annotation["image_id"] | ||
category_id = annotation["category_id"] | ||
bbox = annotation["bbox"] | ||
|
||
category_name = category_id_to_name[category_id] | ||
if image_id not in result: | ||
result[image_id] = {} | ||
if category_name not in result[image_id]: | ||
result[image_id][category_name] = [] | ||
|
||
result[image_id][category_name].append(bbox) | ||
|
||
return list(result.values()) | ||
|
||
|
||
class TestDummyRegression: | ||
image = cv2.imread(IMAGE_PATH) | ||
gt_bboxes = load_gt_bboxes(COCO_ANN_PATH) | ||
pointing_game = PointingGame() | ||
|
||
preprocess_fn = get_preprocess_fn( | ||
change_channel_order=True, | ||
input_size=(224, 224), | ||
hwc_to_chw=True, | ||
) | ||
|
||
@pytest.fixture(autouse=True) | ||
def setup(self, fxt_data_root): | ||
data_dir = fxt_data_root | ||
retrieve_otx_model(data_dir, MODEL_NAME) | ||
model_path = data_dir / "otx_models" / (MODEL_NAME + ".xml") | ||
model = ov.Core().read_model(model_path) | ||
|
||
self.explainer = Explainer( | ||
model=model, | ||
task=Task.CLASSIFICATION, | ||
preprocess_fn=self.preprocess_fn, | ||
explain_mode=ExplainMode.WHITEBOX, | ||
) | ||
|
||
def test_explainer_image(self): | ||
explanation = self.explainer( | ||
self.image, | ||
targets=["person"], | ||
label_names=VOC_NAMES, | ||
colormap=False, | ||
) | ||
assert len(explanation.saliency_map) == 1 | ||
score = self.pointing_game.evaluate([explanation], self.gt_bboxes) | ||
assert score == 1.0 | ||
|
||
def test_explainer_images(self): | ||
images = [self.image, self.image] | ||
explanations = [] | ||
for image in images: | ||
explanation = self.explainer( | ||
image, | ||
targets=["person"], | ||
label_names=VOC_NAMES, | ||
colormap=False, | ||
) | ||
explanations.append(explanation) | ||
dataset_gt_bboxes = self.gt_bboxes * 2 | ||
|
||
score = self.pointing_game.evaluate(explanations, dataset_gt_bboxes) | ||
assert score == 1.0 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import logging | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from openvino_xai.explainer.explanation import Explanation | ||
from openvino_xai.metrics.pointing_game import PointingGame | ||
|
||
|
||
class TestPointingGame: | ||
@pytest.fixture(autouse=True) | ||
def setUp(self): | ||
self.pointing_game = PointingGame() | ||
|
||
def test_pointing_game(self): | ||
saliency_map = np.zeros((3, 3), dtype=np.float32) | ||
saliency_map[1, 1] = 1 | ||
|
||
ground_truth_bbox = [(1, 1, 1, 1)] | ||
score = self.pointing_game.pointing_game(saliency_map, ground_truth_bbox) | ||
assert score == 1 | ||
|
||
ground_truth_bbox = [(0, 0, 0, 0)] | ||
score = self.pointing_game.pointing_game(saliency_map, ground_truth_bbox) | ||
assert score == 0 | ||
|
||
def test_pointing_game_evaluate(self, caplog): | ||
pointing_game = PointingGame() | ||
explanation = Explanation( | ||
label_names=["cat", "dog"], | ||
targets=[0, 1], | ||
saliency_map={0: [[0, 1], [2, 3]], 1: [[0, 0], [0, 1]]}, | ||
) | ||
explanations = [explanation] | ||
|
||
gt_bboxes = [{"cat": [(0, 0, 2, 2)], "dog": [(0, 0, 1, 1)]}] | ||
score = pointing_game.evaluate(explanations, gt_bboxes) | ||
assert score == 1.0 | ||
|
||
# No hit for dog class saliency map, hit for cat class saliency map | ||
gt_bboxes = [{"cat": [(0, 0, 2, 2), (0, 0, 1, 1)], "dog": [(0, 0, 0, 0)]}] | ||
score = pointing_game.evaluate(explanations, gt_bboxes) | ||
assert score == 0.5 | ||
|
||
# No ground truth bboxes for available saliency map classes | ||
gt_bboxes = [{"not-cat": [(0, 0, 2, 2)], "not-dog": [(0, 0, 0, 0)]}] | ||
with caplog.at_level(logging.INFO): | ||
score = pointing_game.evaluate(explanations, gt_bboxes) | ||
assert "Skip pointing game evaluation for this saliency map." in caplog.text | ||
assert score == 0.0 | ||
|
||
# Ground truth bboxes / saliency maps number mismatch | ||
gt_bboxes = [] | ||
with pytest.raises(AssertionError): | ||
score = pointing_game.evaluate(explanations, gt_bboxes) | ||
|
||
# No label names | ||
explanation = Explanation( | ||
label_names=None, | ||
targets=[0, 1], | ||
saliency_map={0: [[0, 1], [2, 3]], 1: [[0, 0], [0, 1]]}, | ||
) | ||
explanations = [explanation] | ||
gt_bboxes = [{"cat": [(0, 0, 2, 2)], "dog": [(0, 0, 1, 1)]}] | ||
with pytest.raises(AssertionError): | ||
score = pointing_game.evaluate(explanations, gt_bboxes) |