diff --git a/openvino_xai/methods/base.py b/openvino_xai/methods/base.py index 9fdf0c55..cb385ac0 100644 --- a/openvino_xai/methods/base.py +++ b/openvino_xai/methods/base.py @@ -64,12 +64,6 @@ def generate_saliency_map(self, data: np.ndarray) -> Dict[int, np.ndarray] | np. """Saliency map generation.""" -class OVMethod(MethodBase[ov.Model, ov.CompiledModel]): - def load_model(self) -> None: - core = ov.Core() - self._model_compiled = core.compile_model(model=self._model, device_name=self._device_name) - - @dataclass class Prediction: label: int | None = None diff --git a/openvino_xai/methods/black_box/base.py b/openvino_xai/methods/black_box/base.py index 541c27b2..9a311b09 100644 --- a/openvino_xai/methods/black_box/base.py +++ b/openvino_xai/methods/black_box/base.py @@ -8,11 +8,11 @@ import openvino.runtime as ov from openvino_xai.common.utils import IdentityPreprocessFN -from openvino_xai.methods.base import OVMethod +from openvino_xai.methods.base import MethodBase from openvino_xai.methods.black_box.utils import check_classification_output -class BlackBoxXAIMethod(OVMethod): +class BlackBoxXAIMethod(MethodBase[ov.Model, ov.CompiledModel]): """Base class for methods that explain model in Black-Box mode.""" def __init__( @@ -28,7 +28,7 @@ def __init__( def prepare_model(self, load_model: bool = True) -> ov.Model: """Load model prior to inference.""" if load_model: - self.load_model() + self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name) return self._model def get_logits(self, data_preprocessed: np.ndarray) -> np.ndarray: diff --git a/openvino_xai/methods/white_box/activation_map.py b/openvino_xai/methods/white_box/activation_map.py index a5288188..a45fea60 100644 --- a/openvino_xai/methods/white_box/activation_map.py +++ b/openvino_xai/methods/white_box/activation_map.py @@ -39,7 +39,7 @@ def __new__( **kwargs, ): if isinstance(model, torch.nn.Module): - from .torch import ActivationMap as TorchActivationMap + from .torch import TorchActivationMap return TorchActivationMap(model, *args, **kwargs) return super().__new__(cls) diff --git a/openvino_xai/methods/white_box/base.py b/openvino_xai/methods/white_box/base.py index e7929b77..e7a8e10e 100644 --- a/openvino_xai/methods/white_box/base.py +++ b/openvino_xai/methods/white_box/base.py @@ -16,10 +16,10 @@ has_xai, ) from openvino_xai.inserter.inserter import insert_xai_branch_into_model -from openvino_xai.methods.base import OVMethod +from openvino_xai.methods.base import MethodBase -class WhiteBoxMethod(OVMethod): +class WhiteBoxMethod(MethodBase[ov.Model, ov.CompiledModel]): """ Base class for white-box XAI methods. @@ -64,7 +64,7 @@ def prepare_model(self, load_model: bool = True) -> ov.Model: logger.info("Provided IR model already contains XAI branch.") self._model = self._model_ori if load_model: - self.load_model() + self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name) return self._model xai_output_node = self.generate_xai_branch() @@ -72,7 +72,7 @@ def prepare_model(self, load_model: bool = True) -> ov.Model: if not has_xai(self._model): raise RuntimeError("Insertion of the XAI branch into the model was not successful.") if load_model: - self.load_model() + self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name) return self._model @staticmethod diff --git a/openvino_xai/methods/white_box/recipro_cam.py b/openvino_xai/methods/white_box/recipro_cam.py index 413dccde..e5ae6f0b 100644 --- a/openvino_xai/methods/white_box/recipro_cam.py +++ b/openvino_xai/methods/white_box/recipro_cam.py @@ -86,7 +86,7 @@ def __new__( **kwargs, ): if isinstance(model, torch.nn.Module): - from .torch import ReciproCAM as TorchReciproCAM + from .torch import TorchReciproCAM return TorchReciproCAM(model, *args, **kwargs) return super().__new__(cls) @@ -191,7 +191,7 @@ def __new__( **kwargs, ): if isinstance(model, torch.nn.Module): - from .torch import ViTReciproCAM as TorchViTReciproCAM + from .torch import TorchViTReciproCAM return TorchViTReciproCAM(model, *args, **kwargs) return super().__new__(cls) diff --git a/openvino_xai/methods/white_box/torch.py b/openvino_xai/methods/white_box/torch.py index 01c19262..8ac707da 100644 --- a/openvino_xai/methods/white_box/torch.py +++ b/openvino_xai/methods/white_box/torch.py @@ -14,7 +14,7 @@ from openvino_xai.methods.base import IdentityPreprocessFN, MethodBase -class TorchMethod(MethodBase[torch.nn.Module, torch.nn.Module]): +class TorchWhiteBoxMethod(MethodBase[torch.nn.Module, torch.nn.Module]): """ Base class for Torch-based methods. @@ -103,7 +103,7 @@ def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor: return saliency_map.to(torch.uint8) -class ActivationMap(TorchMethod): +class TorchActivationMap(TorchWhiteBoxMethod): """ActivationMap. Mean of the feature map along the channel dimension.""" def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]: @@ -120,7 +120,7 @@ def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tenso } -class ReciproCAM(TorchMethod): +class TorchReciproCAM(TorchWhiteBoxMethod): """Implementation of Recipro-CAM for class-wise saliency map. Recipro-CAM: gradient-free reciprocal class activation map (https://arxiv.org/pdf/2209.14074.pdf) @@ -181,7 +181,7 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w: return mosaic_feature_map -class ViTReciproCAM(ReciproCAM): +class TorchViTReciproCAM(TorchReciproCAM): """Implementation of ViTRecipro-CAM for class-wise saliency map for transformer-based classifiers. ViT-ReciproCAM: Gradient and Attention-Free Visual Explanations for Vision Transformer diff --git a/tests/unit/methods/white_box/test_torch.py b/tests/unit/methods/white_box/test_torch.py index 8032768d..fe5be2f5 100644 --- a/tests/unit/methods/white_box/test_torch.py +++ b/tests/unit/methods/white_box/test_torch.py @@ -11,10 +11,10 @@ from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME, has_xai from openvino_xai.methods.white_box.torch import ( - ActivationMap, - ReciproCAM, + TorchActivationMap, TorchMethod, - ViTReciproCAM, + TorchReciproCAM, + TorchViTReciproCAM, ) @@ -103,7 +103,7 @@ def test_activationmap() -> None: batch_size = 2 num_classes = 3 model = DummyCNN(num_classes=num_classes) - method = ActivationMap(model=model, target_layer="feature") + method = TorchActivationMap(model=model, target_layer="feature") model_xai = method.prepare_model() assert has_xai(model_xai) data = np.random.rand(batch_size, 3, 5, 5) @@ -121,7 +121,7 @@ def test_reciprocam(optimize_gap: bool) -> None: batch_size = 2 num_classes = 3 model = DummyCNN(num_classes=num_classes) - method = ReciproCAM(model=model, target_layer="feature", optimize_gap=optimize_gap) + method = TorchReciproCAM(model=model, target_layer="feature", optimize_gap=optimize_gap) model_xai = method.prepare_model() assert has_xai(model_xai) data = np.random.rand(batch_size, 4, 5, 5) @@ -140,7 +140,9 @@ def test_vitreciprocam(use_gaussian: bool, use_cls_token: bool) -> None: batch_size = 2 num_classes = 3 model = DummyVIT(num_classes=num_classes) - method = ViTReciproCAM(model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token) + method = TorchViTReciproCAM( + model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token + ) model_xai = method.prepare_model() assert has_xai(model_xai) data = np.random.rand(batch_size, 4, 5, 5)