Try CNN -> ViT assumption for IR insertion (#48)
* Update change log

* Try CNN -> ViT assumption for IR insertion

* Fix pre-commit

* Enable botnet as CNN

* Enable botnext & edgenext as CNN

* Enable convmixer as CNN

* Guess WB method in func tests

* Update change log

* RISE -> AISE in tests as default method

---------

Co-authored-by: Galina Zalesskaya <[email protected]>
goodsong81 and GalyaZalesskaya authored Aug 2, 2024
1 parent c497c01 commit 2967389
Showing 4 changed files with 74 additions and 80 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.md
@@ -4,21 +4,26 @@

### Summary

* Support OpenVINO IR (.xml) / ONNX (.onnx) model file for `Explainer` model
* Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models.

### What's Changed

* Use OVC converted models in func tests by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/44
* Update CodeCov action by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/46
* Refactor OpenVINO imports by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/45
* Support OV IR / ONNX model file for Explainer by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/47
* Try CNN -> ViT assumption for IR insertion by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/48
* Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models by @negvet in https://github.com/openvinotoolkit/openvino_xai/pull/49

### Known Issues

* OpenVINO IR branch insertion not working for models converted directly from torch models in https://github.com/openvinotoolkit/openvino_xai/issues/26
* Runtime error from ONNX / OpenVINO IR models during conversion or inference for XAI in https://github.com/openvinotoolkit/openvino_xai/issues/29
* Models not supported by white box XAI methods in https://github.com/openvinotoolkit/openvino_xai/issues/30

### New Contributors

*
* N/A

---

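The changelog entries above cover two user-facing features: loading a model file directly into the `Explainer`, and the new AISE black-box method. A minimal sketch of the file-based usage follows; the top-level `Explainer` import, keyword names, model path, and preprocessing are illustrative assumptions, not taken from this commit.

```python
# Sketch only: constructing an Explainer from an OpenVINO IR (.xml) or ONNX (.onnx) file.
# "model.xml" is a hypothetical path; preprocess_fn below is a stand-in for real preprocessing.
import numpy as np

from openvino_xai import Explainer  # top-level import assumed
from openvino_xai.common.parameters import Task


def preprocess_fn(image: np.ndarray) -> np.ndarray:
    # Placeholder: in practice, resize/normalize to the model's expected input layout.
    return np.expand_dims(image, 0)


explainer = Explainer(
    model="model.xml",  # model file path, accepted per the changelog entry above
    task=Task.CLASSIFICATION,
    preprocess_fn=preprocess_fn,
)
```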
23 changes: 15 additions & 8 deletions openvino_xai/methods/factory.py
@@ -107,14 +107,21 @@ def create_classification_method(

if explain_method is None or explain_method == Method.RECIPROCAM:
logger.info("Using ReciproCAM method (for CNNs).")
return ReciproCAM(
model,
preprocess_fn,
target_layer,
embed_scaling,
device_name,
**kwargs,
)
try:
return ReciproCAM(
model,
preprocess_fn,
target_layer,
embed_scaling,
device_name,
**kwargs,
)
except Exception as e:
if explain_method is None:
logger.info(f"Not successfull due to '{e}'. Trying another methods.")
explain_method = Method.VITRECIPROCAM
else:
raise e
if explain_method == Method.VITRECIPROCAM:
logger.info("Using ViTReciproCAM method (for vision transformers).")
return ViTReciproCAM(
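The hunk above is the core of this commit: with `explain_method=None`, `create_classification_method` first assumes a CNN and tries ReciproCAM, then falls back to ViTReciproCAM if insertion fails; an explicit `Method.RECIPROCAM` request still re-raises the error. A hedged usage sketch follows; the model path is a placeholder and the preprocessing callable is a trivial stand-in.

```python
# Sketch only: letting the white-box factory guess the method (CNN first, then ViT).
# "model.xml" is a hypothetical classification IR; preprocess_fn is a minimal stand-in.
import numpy as np
import openvino as ov

from openvino_xai.common.parameters import Task
from openvino_xai.methods.factory import WhiteBoxMethodFactory


def preprocess_fn(image: np.ndarray) -> np.ndarray:
    # Placeholder preprocessing; adapt to the model's expected input.
    return np.expand_dims(image, 0)


model = ov.Core().read_model("model.xml")
method = WhiteBoxMethodFactory.create_method(
    task=Task.CLASSIFICATION,
    model=model,
    preprocess_fn=preprocess_fn,
    explain_method=None,  # auto mode: try ReciproCAM, fall back to ViTReciproCAM on failure
)
```

With a CNN-style IR this returns a ReciproCAM instance; with a ViT-style IR the first attempt fails during insertion and a ViTReciproCAM instance is returned instead.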
82 changes: 14 additions & 68 deletions tests/func/test_classification_timm_full.py
@@ -29,48 +29,6 @@

TEST_MODELS = timm.list_models(pretrained=True)

CNN_MODELS = [
"bat_resnext",
"convnext",
"cs3",
"cs3darknet",
"darknet",
"densenet",
"dla",
"dpn",
"efficientnet",
"ese_vovnet",
"fbnet",
"gernet",
"ghostnet",
"hardcorenas",
"hrnet",
"inception",
"lcnet",
"legacy_",
"mixnet",
"mnasnet",
"mobilenet",
"nasnet",
"regnet",
"repvgg",
"res2net",
"res2next",
"resnest",
"resnet",
"resnext",
"rexnet",
"selecsls",
"semnasnet",
"senet",
"seresnext",
"spnasnet",
"tinynet",
"tresnet",
"vgg",
"xception",
]

SUPPORTED_BUT_FAILED_BY_BB_MODELS = {}

NOT_SUPPORTED_BY_BB_MODELS = {
@@ -82,7 +40,7 @@
"dm_nfnet": "openvino._pyopenvino.GeneralFailure: Check 'false' failed at src/frontends/onnx/frontend/src/frontend.cpp:144",
"eca_nfnet": "openvino._pyopenvino.GeneralFailure: Check 'false' failed at src/frontends/onnx/frontend/src/frontend.cpp:144",
"eva_giant": "RuntimeError: The serialized model is larger than the 2GiB limit imposed by the protobuf library.",
"halo": "torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible.",
# "halo": "torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible.",
"nf_regnet": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.",
"nf_resnet": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.",
"nfnet_l0": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.",
@@ -110,6 +68,7 @@
**NOT_SUPPORTED_BY_BB_MODELS,
# Killed on WB
"beit_large_patch16_512": "Failed to allocate 94652825600 bytes of memory",
"convmixer_1536_20": "OOM Killed",
"eva_large_patch14_336": "OOM Killed",
"eva02_base_patch14_448": "OOM Killed",
"eva02_large_patch14_448": "OOM Killed",
@@ -127,32 +86,23 @@
"xcit_small_12_p8_384": "OOM Killed",
"xcit_small_24_p8_384": "OOM Killed",
# Not expected to work for now
"botnet26t_256": "Only two outputs of the between block Add node supported, but got 1",
"caformer": "One (and only one) of the nodes has to be Add type. But got MVN and Multiply.",
"cait_": "Cannot create an empty Constant. Please provide valid data.",
"coat_": "Only two outputs of the between block Add node supported, but got 1.",
"coatn": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"convmixer": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"crossvit": "One (and only one) of the nodes has to be Add type. But got StridedSlice and StridedSlice.",
"davit": "Only two outputs of the between block Add node supported, but got 1.",
"eca_botnext": "Only two outputs of the between block Add node supported, but got 1.",
"edgenext": "Only two outputs of the between block Add node supported, but got 1",
"efficientformer": "Cannot find output backbone_node in auto mode.",
"focalnet": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"gcvit": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"levit_": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"maxvit": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"maxxvit": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"mobilevitv2": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"nest_": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"poolformer": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"sebotnet": "Only two outputs of the between block Add node supported, but got 1.",
# work in CNN mode -> "davit": "Only two outputs of the between block Add node supported, but got 1.",
# work in CNN mode -> "efficientformer": "Cannot find output backbone_node in auto mode.",
# work in CNN mode -> "focalnet": "Cannot find output backbone_node in auto mode, please provide target_layer.",
# work in CNN mode -> "gcvit": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"levit_": "Check 'TRShape::merge_into(output_shape, in_copy)' failed",
# work in CNN mode -> "maxvit": "Cannot find output backbone_node in auto mode, please provide target_layer.",
# work in CNN mode -> "maxxvit": "Cannot find output backbone_node in auto mode, please provide target_layer.",
# work in CNN mode -> "mobilevitv2": "Cannot find output backbone_node in auto mode, please provide target_layer.",
# work in CNN mode -> "nest_": "Cannot find output backbone_node in auto mode, please provide target_layer.",
# work in CNN mode -> "poolformer": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"sequencer2d": "Cannot find output backbone_node in auto mode, please provide target_layer.",
"tnt_s_patch16_224": "Only two outputs of the between block Add node supported, but got 1.",
"tresnet": "Batch shape of the output should be dynamic, but it is static.",
"twins": "One (and only one) of the nodes has to be Add type. But got ShapeOf and Transpose.",
"visformer": "Cannot find output backbone_node in auto mode, please provide target_layer",
"vit_relpos_base_patch32_plus_rpn_256": "Check 'TRShape::merge_into(output_shape, in_copy)' failed",
# work in CNN mode -> "visformer": "Cannot find output backbone_node in auto mode, please provide target_layer",
"vit_relpos_medium_patch16_rpn_224": "ValueError in openvino_xai/methods/white_box/recipro_cam.py:215",
}

@@ -184,11 +134,7 @@ def test_classification_white_box(self, model_id, dump_maps=False):
if failed_model in model_id:
pytest.xfail(reason=SUPPORTED_BUT_FAILED_BY_WB_MODELS[failed_model])

explain_method = Method.VITRECIPROCAM
for cnn_model in CNN_MODELS:
if cnn_model in model_id:
explain_method = Method.RECIPROCAM
break
explain_method = None

timm_model, model_cfg = self.get_timm_model(model_id)
input_size = list(timm_model.default_cfg["input_size"])
@@ -5,11 +5,13 @@

import openvino as ov
import pytest
from pytest_mock import MockerFixture

from openvino_xai.common.parameters import Method, Task
from openvino_xai.common.utils import retrieve_otx_model
from openvino_xai.explainer.utils import get_preprocess_fn
from openvino_xai.methods.factory import WhiteBoxMethodFactory
from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn
from openvino_xai.methods.black_box.aise import AISE
from openvino_xai.methods.factory import BlackBoxMethodFactory, WhiteBoxMethodFactory
from openvino_xai.methods.white_box.activation_map import ActivationMap
from openvino_xai.methods.white_box.det_class_probability_map import (
DetClassProbabilityMap,
@@ -75,6 +77,40 @@ def test_create_wb_cls_vit_method(fxt_data_root: Path):
assert isinstance(explain_method, ViTReciproCAM)


def test_create_wb_cls_guess_method(mocker: MockerFixture):
model = mocker.MagicMock()
# method=None -> ReciproCAM fail -> ViTReciproCAM
recipro_cam = mocker.patch("openvino_xai.methods.factory.ReciproCAM", side_effect=Exception("DUMMY REASON"))
vit_recipro_cam = mocker.patch("openvino_xai.methods.factory.ViTReciproCAM")
explain_method = WhiteBoxMethodFactory.create_method(
task=Task.CLASSIFICATION,
model=model,
preprocess_fn=PREPROCESS_FN,
explain_method=None,
)
vit_recipro_cam.assert_called()
# method=ReciproCAM -> ReciproCAM fail -> Exception
recipro_cam = mocker.patch("openvino_xai.methods.factory.ReciproCAM", side_effect=Exception("DUMMY REASON"))
vit_recipro_cam = mocker.patch("openvino_xai.methods.factory.ViTReciproCAM")
with pytest.raises(Exception) as exc_info:
explain_method = WhiteBoxMethodFactory.create_method(
task=Task.CLASSIFICATION,
model=model,
preprocess_fn=PREPROCESS_FN,
explain_method=Method.RECIPROCAM,
)
vit_recipro_cam.assert_not_called()
assert str(exc_info.value) == "DUMMY REASON"


def test_create_bb_cls_vit_method(fxt_data_root: Path):
retrieve_otx_model(fxt_data_root, VIT_MODEL)
model_path = fxt_data_root / "otx_models" / (VIT_MODEL + ".xml")
model_vit = ov.Core().read_model(model_path)
explain_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model_vit, get_postprocess_fn())
assert isinstance(explain_method, AISE)


def test_create_wb_det_cnn_method(fxt_data_root: Path):
retrieve_otx_model(fxt_data_root, DEFAULT_DET_MODEL)
model_path = fxt_data_root / "otx_models" / (DEFAULT_DET_MODEL + ".xml")
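The new `test_create_bb_cls_vit_method` above reflects the "RISE -> AISE" change: the black-box factory now returns AISE by default for classification. A minimal sketch mirroring that test, assuming a locally available IR file ("model.xml" is a hypothetical path):

```python
# Sketch only: the default black-box classification method is now AISE.
import openvino as ov

from openvino_xai.common.parameters import Task
from openvino_xai.explainer.utils import get_postprocess_fn
from openvino_xai.methods.black_box.aise import AISE
from openvino_xai.methods.factory import BlackBoxMethodFactory

model = ov.Core().read_model("model.xml")  # hypothetical classification IR
bb_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn())
assert isinstance(bb_method, AISE)  # the tests previously defaulted to RISE
```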
