From 36369ad5b5224e31ed0a0975c7d1f74ed76a5416 Mon Sep 17 00:00:00 2001 From: Songki Choi Date: Wed, 4 Sep 2024 14:23:00 +0900 Subject: [PATCH] Add integration test --- tests/intg/test_classification_timm.py | 57 ++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/intg/test_classification_timm.py b/tests/intg/test_classification_timm.py index d2fb3d47..d7667d61 100644 --- a/tests/intg/test_classification_timm.py +++ b/tests/intg/test_classification_timm.py @@ -11,6 +11,7 @@ import openvino as ov import pytest +from openvino_xai import insert_xai from openvino_xai.common.parameters import Method, Task from openvino_xai.explainer.explainer import Explainer, ExplainMode from openvino_xai.explainer.utils import ( @@ -363,6 +364,7 @@ def test_ovc_model_white_box(self, model_id): assert explanation is not None assert explanation.shape[-1] > 1 and explanation.shape[-2] > 1 print(f"{model_id}: Generated classification saliency maps with shape {explanation.shape}.") + self.clear_cache() @pytest.mark.parametrize( "model_id", @@ -441,6 +443,61 @@ def test_model_format(self, model_id, explain_mode, model_format): assert explanation is not None assert explanation.shape[-1] > 1 and explanation.shape[-2] > 1 print(f"{model_id}: Generated classification saliency maps with shape {explanation.shape}.") + self.clear_cache() + + @pytest.mark.parametrize( + "model_id", + [ + "resnet18.a1_in1k", + "efficientnet_b0.ra_in1k", + "vit_tiny_patch16_224.augreg_in21k", + "deit_tiny_patch16_224.fb_in1k", + ], + ) + def test_torch_insert_xai_with_layer(self, model_id: str): + xai_cfg = { + "resnet18.a1_in1k": ("layer4", Method.RECIPROCAM), + "efficientnet_b0.ra_in1k": ("bn2", Method.RECIPROCAM), + "vit_tiny_patch16_224.augreg_in21k": ("blocks.9.norm1", Method.VITRECIPROCAM), + "deit_tiny_patch16_224.fb_in1k": ("blocks.9.norm1", Method.VITRECIPROCAM), + } + + model_dir = self.data_dir / "timm_models" / "converted_models" + model, model_cfg = self.get_timm_model(model_id, model_dir) + + image = cv2.imread("tests/assets/cheetah_person.jpg") + image = cv2.resize(image, dsize=model_cfg["input_size"][1:]) + image = cv2.cvtColor(image, code=cv2.COLOR_BGR2RGB) + mean = np.array(model.default_cfg["mean"]) + std = np.array(model.default_cfg["std"]) + image_norm = (image / 255.0 - mean) / std + image_norm = image_norm.transpose((2, 0, 1)) # HWC -> CHW + image_norm = image_norm[None, :] # CHW -> 1CHW + target_class = self.supported_num_classes[model_cfg["num_classes"]] + + xai_model: torch.nn.Module = insert_xai( + model, + task=Task.CLASSIFICATION, + target_layer=xai_cfg[model_id][0], + explain_method=xai_cfg[model_id][1], + ) + + with torch.no_grad(): + xai_model.eval() + xai_output = xai_model(torch.from_numpy(image_norm).float()) + xai_logit = xai_output["prediction"] + xai_prob = torch.softmax(xai_logit, dim=-1) + xai_label = xai_prob.argmax(dim=-1)[0] + assert xai_label.item() == target_class + assert xai_prob[0, xai_label].item() > 0.0 + + saliency_map: np.ndarray = xai_output["saliency_map"].numpy(force=True) + saliency_map = saliency_map.squeeze(0) + assert saliency_map.shape[-1] > 1 and saliency_map.shape[-2] > 1 + assert saliency_map.min() < saliency_map.max() + assert saliency_map.dtype == np.uint8 + + self.clear_cache() def check_for_saved_map(self, model_id, directory): for target in self.supported_num_classes.values():