
Commit

Support multi-class annotations and several bounding boxes per image
GalyaZalesskaya committed Aug 19, 2024
1 parent fc55e93 commit 5d966cb
Showing 3 changed files with 143 additions and 58 deletions.
82 changes: 56 additions & 26 deletions openvino_xai/metrics/pointing_game.py
@@ -1,47 +1,77 @@
from typing import List, Tuple
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Tuple

import numpy as np

from openvino_xai.common.utils import logger
from openvino_xai.explainer.explanation import Explanation


class PointingGame:
@staticmethod
def pointing_game(saliency_map: np.ndarray, gt_bbox: Tuple[int, int, int, int]) -> bool:
def pointing_game(saliency_map: np.ndarray, image_gt_bboxes: List[Tuple[int, int, int, int]]) -> bool:
"""
Implements the Pointing Game metric using bounding boxes. Returns a boolean indicating
if any of the most salient points falls within the ground truth bounding box.
Implements the Pointing Game metric using a saliency map and bounding boxes of the same image and class.
Returns a boolean indicating if any of the most salient points fall within the ground truth bounding boxes.
:param saliency_map: A 2D numpy array representing the saliency map for the image.
:type saliency_map: np.ndarray
:param gt_bbox: A tuple (x, y, w, h) representing the bounding box of the ground truth object.
:type gt_bbox: Tuple[int, int, int, int]
"""
# TODO: Support a case with multiple bounding boxes for one image
x, y, w, h = gt_bbox
:param image_gt_bboxes: A list of tuples (x, y, w, h) representing the bounding boxes of the ground truth objects.
:type image_gt_bboxes: List[Tuple[int, int, int, int]]
:return: True if any of the most salient points fall within any of the ground truth bounding boxes, False otherwise.
:rtype: bool
"""
# Find the most salient points in the saliency map
max_indices = np.argwhere(saliency_map == np.max(saliency_map))

for max_point_y, max_point_x in max_indices:
# Check if this point is within the ground truth bounding box
if x <= max_point_x <= x + w and y <= max_point_y <= y + h:
return True
# If multiple bounding boxes are available for one image
for x, y, w, h in image_gt_bboxes:
for max_point_y, max_point_x in max_indices:
# Check if this point is within the ground truth bounding box
if x <= max_point_x <= x + w and y <= max_point_y <= y + h:
return True
return False

def evaluate(self, saliency_maps: List[np.ndarray], gt_bboxes: List[Tuple[int, int, int, int]]) -> float:
def evaluate(
self, explanations: List[Explanation], gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]]
) -> float:
"""
Evaluates the Pointing Game metric over a set of images.
Evaluates the Pointing Game metric over a set of images. Skips a saliency map if no ground-truth bboxes exist for its class.
:param explanations: A list of explanations for each image.
:type explanations: List[Explanation]
:param gt_bboxes: A list of dictionaries {label_name: lists of bounding boxes} for each image.
:type gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]]
:param saliency_maps: A list of 2D numpy arrays representing the saliency maps.
:type saliency_maps: List[np.ndarray]
:param gt_bboxes: A list of bounding boxes of the ground truth objects for each image.
:type gt_bboxes: List[Tuple[int, int, int, int]]
:return: Pointing game score over a list of images
:rtype: float
"""
assert len(saliency_maps) == len(

assert len(explanations) == len(
gt_bboxes
), "Number of saliency maps and ground truth bounding boxes must match."
), "Number of explanations and ground truth bounding boxes must match and equal to number of images."

hits = 0
num_sal_maps = 0
for explanation, image_gt_bboxes in zip(explanations, gt_bboxes):
label_names = explanation.label_names
assert label_names is not None, "Label names are required for pointing game evaluation."

for class_idx, class_sal_map in explanation.saliency_map.items():
label_name = label_names[int(class_idx)]

if label_name not in image_gt_bboxes:
logger.info(
f"No ground-truth bbox for {label_name} saliency map. "
f"Skip pointing game evaluation for this saliency map."
)
continue

class_gt_bboxes = image_gt_bboxes[label_name]
hits += self.pointing_game(class_sal_map, class_gt_bboxes)
num_sal_maps += 1

hits = sum(
[self.pointing_game(s_map, image_gt_bboxes) for s_map, image_gt_bboxes in zip(saliency_maps, gt_bboxes)]
)
score = hits / len(saliency_maps)
return score
return hits / num_sal_maps if num_sal_maps > 0 else 0.0
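
For quick reference, a minimal usage sketch of the updated API (not part of the commit; the Explanation arguments simply mirror the unit test further down and may differ for other versions of the constructor):

import numpy as np

from openvino_xai.explainer.explanation import Explanation
from openvino_xai.metrics.pointing_game import PointingGame

# Per-class hit test: the peak of the saliency map must fall inside any of the
# ground-truth boxes (x, y, w, h) of that image and class.
saliency_map = np.zeros((3, 3), dtype=np.float32)
saliency_map[1, 1] = 1
assert PointingGame.pointing_game(saliency_map, [(0, 0, 0, 0), (1, 1, 1, 1)])  # hit via the second box

# Dataset-level evaluation: one Explanation per image (per-class saliency maps)
# and one {label_name: [bboxes]} dict per image.
explanation = Explanation(
    label_names=["cat", "dog"], saliency_map={0: [[0, 1], [2, 3]], 1: [[0, 0], [0, 1]]}, targets=[0, 1]
)
gt_bboxes = [{"cat": [(0, 0, 2, 2)], "dog": [(0, 0, 1, 1)]}]
score = PointingGame().evaluate([explanation], gt_bboxes)  # both peaks land in their boxes -> 1.0
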
72 changes: 48 additions & 24 deletions tests/regression/test_regression.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import json
from typing import Dict, List, Tuple

import cv2
import openvino as ov
@@ -15,34 +16,51 @@
from tests.unit.explanation.test_explanation_utils import VOC_NAMES

MODEL_NAME = "mlc_mobilenetv3_large_voc"
IMAGE_PATH = "tests/assets/cheetah_person.jpg"
COCO_ANN_PATH = "tests/assets/cheetah_person_coco.json"


def load_gt_bboxes(class_name="person"):
with open("tests/assets/cheetah_person_coco.json", "r") as f:
def load_gt_bboxes(json_coco_path: str) -> List[Dict[str, List[Tuple[int, int, int, int]]]]:
"""
Loads ground truth bounding boxes from a COCO format JSON file.
Returns a list of dictionaries, where each dictionary corresponds to an image.
The key is the label name and the value is a list of bounding boxes for that image.
"""

with open(json_coco_path, "r") as f:
coco_anns = json.load(f)

category_id = [category["id"] for category in coco_anns["categories"] if category["name"] == class_name]
category_id = category_id[0]
result = {}
category_id_to_name = {category["id"]: category["name"] for category in coco_anns["categories"]}

for annotation in coco_anns["annotations"]:
image_id = annotation["image_id"]
category_id = annotation["category_id"]
bbox = annotation["bbox"]

category_name = category_id_to_name[category_id]
if image_id not in result:
result[image_id] = {}
if category_name not in result[image_id]:
result[image_id][category_name] = []

category_gt_bboxes = [
annotation["bbox"] for annotation in coco_anns["annotations"] if annotation["category_id"] == category_id
]
return category_gt_bboxes
result[image_id][category_name].append(bbox)

return list(result.values())


class TestDummyRegression:
image = cv2.imread("tests/assets/cheetah_person.jpg")
image = cv2.imread(IMAGE_PATH)
gt_bboxes = load_gt_bboxes(COCO_ANN_PATH)
pointing_game = PointingGame()

preprocess_fn = get_preprocess_fn(
change_channel_order=True,
input_size=(224, 224),
hwc_to_chw=True,
)

gt_bboxes = load_gt_bboxes()
pointing_game = PointingGame()
steps = 10

@pytest.fixture(autouse=True)
def setup(self, fxt_data_root):
data_dir = fxt_data_root
@@ -65,26 +83,32 @@ def test_explainer_image(self):
colormap=False,
)
assert len(explanation.saliency_map) == 1
score = self.pointing_game.evaluate([explanation], self.gt_bboxes)
assert score == 1.0

# For now, assume that there's only one class
# TODO: support multiple classes
saliency_maps = list(explanation.saliency_map.values())
score = self.pointing_game.evaluate(saliency_maps, self.gt_bboxes)
assert score > 0.5
explanation = self.explainer(
self.image,
targets=["cat"],
label_names=VOC_NAMES,
colormap=False,
)
assert len(explanation.saliency_map) == 1
score = self.pointing_game.evaluate([explanation], self.gt_bboxes)
# No gt box for "cat" class
assert score == 0.0

def test_explainer_images(self):
# TODO support multiple classes
images = [self.image, self.image]
saliency_maps = []
explanations = []
for image in images:
explanation = self.explainer(
image,
targets=["person"],
label_names=VOC_NAMES,
colormap=False,
)
saliency_map = list(explanation.saliency_map.values())[0]
saliency_maps.append(saliency_map)
explanations.append(explanation)
dataset_gt_bboxes = self.gt_bboxes * 2

score = self.pointing_game.evaluate(saliency_maps, self.gt_bboxes * 2)
assert score > 0.5
score = self.pointing_game.evaluate(explanations, dataset_gt_bboxes)
assert score == 1.0
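
For reference, a rough sketch of what the new load_gt_bboxes helper produces (the annotation values are hypothetical and do not reflect tests/assets/cheetah_person_coco.json; the grouping is a condensed restatement of the helper above):

# Hypothetical COCO-style annotations: one image with two labelled objects.
coco_anns = {
    "categories": [{"id": 1, "name": "person"}, {"id": 2, "name": "cheetah"}],
    "annotations": [
        {"image_id": 0, "category_id": 1, "bbox": [10, 20, 100, 200]},
        {"image_id": 0, "category_id": 2, "bbox": [150, 30, 120, 90]},
    ],
}

# Condensed restatement of the grouping performed by load_gt_bboxes.
id_to_name = {c["id"]: c["name"] for c in coco_anns["categories"]}
per_image: dict = {}
for ann in coco_anns["annotations"]:
    name = id_to_name[ann["category_id"]]
    per_image.setdefault(ann["image_id"], {}).setdefault(name, []).append(ann["bbox"])
gt_bboxes = list(per_image.values())
# -> [{"person": [[10, 20, 100, 200]], "cheetah": [[150, 30, 120, 90]]}]
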
47 changes: 39 additions & 8 deletions tests/unit/metrics/test_pointing_game.py
@@ -1,6 +1,9 @@
import logging

import numpy as np
import pytest

from openvino_xai.explainer.explanation import Explanation
from openvino_xai.metrics.pointing_game import PointingGame


@@ -13,19 +16,47 @@ def test_pointing_game(self):
saliency_map = np.zeros((3, 3), dtype=np.float32)
saliency_map[1, 1] = 1

ground_truth_bbox = (1, 1, 1, 1)
ground_truth_bbox = [(1, 1, 1, 1)]
score = self.pointing_game.pointing_game(saliency_map, ground_truth_bbox)
assert score == 1

ground_truth_bbox = (0, 0, 0, 0)
ground_truth_bbox = [(0, 0, 0, 0)]
score = self.pointing_game.pointing_game(saliency_map, ground_truth_bbox)
assert score == 0

def test_pointing_game_evaluate(self):
saliency_map = np.zeros((3, 3), dtype=np.float32)
saliency_map[1, 1] = 1
def test_pointing_game_evaluate(self, caplog):
pointing_game = PointingGame()
explanation = Explanation(
label_names=["cat", "dog"], saliency_map={0: [[0, 1], [2, 3]], 1: [[0, 0], [0, 1]]}, targets=[0, 1]
)
explanations = [explanation]

saliency_maps = [saliency_map, saliency_map]
ground_truth_bboxes = [(0, 0, 0, 0), (1, 1, 1, 1)]
score = self.pointing_game.evaluate(saliency_maps, ground_truth_bboxes)
gt_bboxes = [{"cat": [(0, 0, 2, 2)], "dog": [(0, 0, 1, 1)]}]
score = pointing_game.evaluate(explanations, gt_bboxes)
assert score == 1.0

# No hit for dog class saliency map, hit for cat class saliency map
gt_bboxes = [{"cat": [(0, 0, 2, 2), (0, 0, 1, 1)], "dog": [(0, 0, 0, 0)]}]
score = pointing_game.evaluate(explanations, gt_bboxes)
assert score == 0.5

# No ground truth bboxes for available saliency map classes
gt_bboxes = [{"not-cat": [(0, 0, 2, 2)], "not-dog": [(0, 0, 0, 0)]}]
with caplog.at_level(logging.INFO):
score = pointing_game.evaluate(explanations, gt_bboxes)
assert "Skip pointing game evaluation for this saliency map." in caplog.text
assert score == 0.0

# Ground truth bboxes / saliency maps number mismatch
gt_bboxes = []
with pytest.raises(AssertionError):
score = pointing_game.evaluate(explanations, gt_bboxes)

# No label names
explanation = Explanation(
label_names=None, saliency_map={0: [[0, 1], [2, 3]], 1: [[0, 0], [0, 1]]}, targets=[0, 1]
)
explanations = [explanation]
gt_bboxes = [{"cat": [(0, 0, 2, 2)], "dog": [(0, 0, 1, 1)]}]
with pytest.raises(AssertionError):
score = pointing_game.evaluate(explanations, gt_bboxes)
