Skip to content

Commit

Permalink
Add integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
goodsong81 committed Sep 4, 2024
1 parent b1fc71e commit 36369ad
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions tests/intg/test_classification_timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import openvino as ov
import pytest

from openvino_xai import insert_xai
from openvino_xai.common.parameters import Method, Task
from openvino_xai.explainer.explainer import Explainer, ExplainMode
from openvino_xai.explainer.utils import (
Expand Down Expand Up @@ -363,6 +364,7 @@ def test_ovc_model_white_box(self, model_id):
assert explanation is not None
assert explanation.shape[-1] > 1 and explanation.shape[-2] > 1
print(f"{model_id}: Generated classification saliency maps with shape {explanation.shape}.")
self.clear_cache()

@pytest.mark.parametrize(
"model_id",
Expand Down Expand Up @@ -441,6 +443,61 @@ def test_model_format(self, model_id, explain_mode, model_format):
assert explanation is not None
assert explanation.shape[-1] > 1 and explanation.shape[-2] > 1
print(f"{model_id}: Generated classification saliency maps with shape {explanation.shape}.")
self.clear_cache()

@pytest.mark.parametrize(
"model_id",
[
"resnet18.a1_in1k",
"efficientnet_b0.ra_in1k",
"vit_tiny_patch16_224.augreg_in21k",
"deit_tiny_patch16_224.fb_in1k",
],
)
def test_torch_insert_xai_with_layer(self, model_id: str):
xai_cfg = {
"resnet18.a1_in1k": ("layer4", Method.RECIPROCAM),
"efficientnet_b0.ra_in1k": ("bn2", Method.RECIPROCAM),
"vit_tiny_patch16_224.augreg_in21k": ("blocks.9.norm1", Method.VITRECIPROCAM),
"deit_tiny_patch16_224.fb_in1k": ("blocks.9.norm1", Method.VITRECIPROCAM),
}

model_dir = self.data_dir / "timm_models" / "converted_models"
model, model_cfg = self.get_timm_model(model_id, model_dir)

image = cv2.imread("tests/assets/cheetah_person.jpg")
image = cv2.resize(image, dsize=model_cfg["input_size"][1:])
image = cv2.cvtColor(image, code=cv2.COLOR_BGR2RGB)
mean = np.array(model.default_cfg["mean"])
std = np.array(model.default_cfg["std"])
image_norm = (image / 255.0 - mean) / std
image_norm = image_norm.transpose((2, 0, 1)) # HWC -> CHW
image_norm = image_norm[None, :] # CHW -> 1CHW
target_class = self.supported_num_classes[model_cfg["num_classes"]]

xai_model: torch.nn.Module = insert_xai(
model,
task=Task.CLASSIFICATION,
target_layer=xai_cfg[model_id][0],
explain_method=xai_cfg[model_id][1],
)

with torch.no_grad():
xai_model.eval()
xai_output = xai_model(torch.from_numpy(image_norm).float())
xai_logit = xai_output["prediction"]
xai_prob = torch.softmax(xai_logit, dim=-1)
xai_label = xai_prob.argmax(dim=-1)[0]
assert xai_label.item() == target_class
assert xai_prob[0, xai_label].item() > 0.0

saliency_map: np.ndarray = xai_output["saliency_map"].numpy(force=True)
saliency_map = saliency_map.squeeze(0)
assert saliency_map.shape[-1] > 1 and saliency_map.shape[-2] > 1
assert saliency_map.min() < saliency_map.max()
assert saliency_map.dtype == np.uint8

self.clear_cache()

def check_for_saved_map(self, model_id, directory):
for target in self.supported_num_classes.values():
Expand Down

0 comments on commit 36369ad

Please sign in to comment.