Skip to content

Commit

Permalink
Fix bugs (#69)
Browse files Browse the repository at this point in the history
* adaptive font_size

* update params

* logging

* font_face -> 2

* adapt text_height

* define offset for text

* gray cmap

* improve gray cmap

* fix detection overlay text

* fix bhwc

* fix save

* minor

* test plot

* fix fp rounding error
  • Loading branch information
negvet authored Sep 17, 2024
1 parent 41ddb11 commit 10a0d96
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 16 deletions.
1 change: 1 addition & 0 deletions openvino_xai/explainer/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ExplainMode(Enum):
Contains the following values:
WHITEBOX - The model is explained in white box mode, i.e. XAI branch is getting inserted into the model graph.
BLACKBOX - The model is explained in black box model.
AUTO - The model is explained in the white-box mode first, if fails - black-box mode will run.
"""

WHITEBOX = "whitebox"
Expand Down
9 changes: 7 additions & 2 deletions openvino_xai/explainer/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def save(
map_to_save = cv2.cvtColor(map_to_save, code=cv2.COLOR_RGB2BGR)
if isinstance(target_idx, str):
target_name = "activation_map"
elif self.label_names and isinstance(target_idx, np.int64) and self.task != Task.DETECTION:
elif self.label_names and isinstance(target_idx, (int, np.int64)) and self.task != Task.DETECTION:
target_name = self.label_names[target_idx]
else:
target_name = str(target_idx)
Expand Down Expand Up @@ -261,7 +261,12 @@ def _plot_matplotlib(self, checked_targets: list[int | str], num_cols: int) -> N

map_to_plot = self.saliency_map[target_index]

axes[i].imshow(map_to_plot)
if map_to_plot.ndim == 3:
axes[i].imshow(map_to_plot)
elif map_to_plot.ndim == 2:
axes[i].imshow(map_to_plot, cmap="gray")
else:
raise ValueError(f"Saliency map expected to be 3 or 2-dimensional, but got {map_to_plot.ndim}.")
axes[i].axis("off") # Hide the axis
axes[i].set_title(f"Class {label_name}")

Expand Down
47 changes: 36 additions & 11 deletions openvino_xai/explainer/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,33 +174,34 @@ def visualize(
# Convert back to dict
return self._update_explanation_with_processed_sal_map(explanation, saliency_map_np, indices_to_return)

@staticmethod
def _put_classification_info(
self,
saliency_map_np: np.ndarray,
indices: List[int],
label_names: List[str] | None,
predictions: Dict[int, Prediction] | None,
) -> None:
corner_location = 3, 17
offset = 3
for smap, target_index in zip(range(len(saliency_map_np)), indices):
label = label_names[target_index] if label_names else str(target_index)
if predictions and target_index in predictions:
score = predictions[target_index].score
if score:
label = f"{label}|{score:.2f}"

font_scale, text_height = self._fit_text_to_image(label, offset, saliency_map_np[smap].shape[1])
cv2.putText(
saliency_map_np[smap],
label,
org=corner_location,
fontFace=1,
fontScale=1.3,
org=(offset, text_height + offset),
fontFace=2,
fontScale=font_scale,
color=(255, 0, 0),
thickness=2,
thickness=1,
)

@staticmethod
def _put_detection_info(
self,
saliency_map_np: np.ndarray,
indices: List[int],
label_names: List[str] | None,
Expand All @@ -209,6 +210,7 @@ def _put_detection_info(
if not predictions:
return

offset = 7
for smap, target_index in zip(range(len(saliency_map_np)), indices):
saliency_map = saliency_map_np[smap]
label_index = predictions[target_index].label
Expand All @@ -220,17 +222,40 @@ def _put_detection_info(

label = label_names[label_index] if label_names else label_index
label_score = f"{label}|{score:.2f}"
box_location = int(x1), int(y1 - 5)

font_scale, _ = self._fit_text_to_image(label_score, x1, saliency_map.shape[1])
box_location = x1, y1 - offset
cv2.putText(
saliency_map,
label_score,
org=box_location,
fontFace=1,
fontScale=1.3,
fontFace=2,
fontScale=font_scale,
color=(255, 0, 0),
thickness=2,
thickness=1,
)

@staticmethod
def _fit_text_to_image(
text: str,
x_start: int,
image_width: int,
font_scale: float = 1.0,
thickness: int = 1,
) -> Tuple[float, int]:
font_face = 2
max_width = image_width - 5
while True:
text_size, _ = cv2.getTextSize(text, font_face, font_scale, thickness)
text_width, text_height = text_size

if x_start + text_width <= max_width:
return font_scale, text_height

font_scale -= 0.1
if abs(font_scale - 0.1) < 0.001:
return font_scale, text_height

@staticmethod
def _apply_scaling(explanation: Explanation, saliency_map_np: np.ndarray) -> np.ndarray:
if explanation.layout not in GRAY_LAYOUTS:
Expand Down
4 changes: 3 additions & 1 deletion openvino_xai/methods/black_box/aise/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import openvino.runtime as ov
from scipy.optimize import direct

from openvino_xai.common.utils import IdentityPreprocessFN
from openvino_xai.common.utils import IdentityPreprocessFN, is_bhwc_layout
from openvino_xai.methods.black_box.base import BlackBoxXAIMethod


Expand Down Expand Up @@ -92,6 +92,8 @@ def _objective_function(self, args) -> float:

kernel_mask = self._mask_generator.generate_kernel_mask(kernel_params)
kernel_mask = np.clip(kernel_mask, 0, 1)
if is_bhwc_layout(self.data_preprocessed):
kernel_mask = np.expand_dims(kernel_mask, 2)

pred_loss_preserve = 0.0
if self.preservation:
Expand Down
20 changes: 19 additions & 1 deletion tests/unit/explainer/test_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
from tests.unit.explainer.test_explanation_utils import VOC_NAMES

SALIENCY_MAPS = (np.random.rand(1, 20, 5, 5) * 255).astype(np.uint8)
SALIENCY_MAPS_DICT = {
0: (np.random.rand(5, 5, 3) * 255).astype(np.uint8),
2: (np.random.rand(5, 5, 3) * 255).astype(np.uint8),
}
SALIENCY_MAPS_DICT_EXCEPTION = {
0: (np.random.rand(5, 5, 3, 2) * 255).astype(np.uint8),
2: (np.random.rand(5, 5, 3, 2) * 255).astype(np.uint8),
}
SALIENCY_MAPS_IMAGE = (np.random.rand(1, 5, 5) * 255).astype(np.uint8)


Expand Down Expand Up @@ -106,7 +114,7 @@ def test_plot(self, mocker, caplog):
# Update the num columns for the matplotlib visualization grid
explanation.plot(backend="matplotlib", num_columns=1)

# Class index that is not in saliency maps will be ommitted with message
# Class index that is not in saliency maps will be omitted with message
with caplog.at_level(logging.INFO):
explanation.plot([0, 3], backend="matplotlib")
assert "Provided class index 3 is not available among saliency maps." in caplog.text
Expand All @@ -123,3 +131,13 @@ def test_plot(self, mocker, caplog):
# Plot activation map
explanation = self._get_explanation(saliency_maps=SALIENCY_MAPS_IMAGE, label_names=None)
explanation.plot()

# Plot colored map
explanation = self._get_explanation(saliency_maps=SALIENCY_MAPS_DICT, label_names=None)
explanation.plot()

# Plot wrong map shape
with pytest.raises(Exception) as exc_info:
explanation = self._get_explanation(saliency_maps=SALIENCY_MAPS_DICT_EXCEPTION, label_names=None)
explanation.plot()
assert str(exc_info.value) == "Saliency map expected to be 3 or 2-dimensional, but got 4."
8 changes: 7 additions & 1 deletion tests/unit/explainer/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from openvino_xai.explainer.visualizer import Visualizer, colormap, overlay, resize
from openvino_xai.methods.base import Prediction

ORIGINAL_INPUT_IMAGE = [
np.ones((100, 100, 3)),
np.ones((10, 10, 3)),
]

SALIENCY_MAPS = [
(np.random.rand(1, 5, 5) * 255).astype(np.uint8),
(np.random.rand(1, 2, 5, 5) * 255).astype(np.uint8),
Expand Down Expand Up @@ -97,6 +102,7 @@ def test_overlay():


class TestVisualizer:
@pytest.mark.parametrize("original_input_image", ORIGINAL_INPUT_IMAGE)
@pytest.mark.parametrize("saliency_maps", SALIENCY_MAPS)
@pytest.mark.parametrize("explain_all_classes", EXPLAIN_ALL_CLASSES)
@pytest.mark.parametrize("task", [Task.CLASSIFICATION, Task.DETECTION])
Expand All @@ -107,6 +113,7 @@ class TestVisualizer:
@pytest.mark.parametrize("overlay_weight", [0.5, 0.3])
def test_visualizer(
self,
original_input_image,
saliency_maps,
explain_all_classes,
task,
Expand All @@ -124,7 +131,6 @@ def test_visualizer(
explanation = Explanation(saliency_maps, targets=explain_targets, task=Task.CLASSIFICATION)

raw_sal_map_dims = len(explanation.shape)
original_input_image = np.ones((20, 20, 3))
visualizer = Visualizer()
explanation = visualizer(
explanation=explanation,
Expand Down

0 comments on commit 10a0d96

Please sign in to comment.