From 3fa9b04b2e9f9e16dd98b276cfb2cae1c00d0a7a Mon Sep 17 00:00:00 2001 From: Songki Choi Date: Thu, 5 Sep 2024 13:18:53 +0900 Subject: [PATCH] Refactor torch method class hierarchy --- openvino_xai/methods/base.py | 6 --- openvino_xai/methods/black_box/base.py | 6 +-- .../methods/white_box/activation_map.py | 2 +- openvino_xai/methods/white_box/base.py | 8 ++-- openvino_xai/methods/white_box/recipro_cam.py | 4 +- openvino_xai/methods/white_box/torch.py | 19 ++++++--- tests/unit/methods/white_box/test_torch.py | 41 ++++++++++++++----- 7 files changed, 53 insertions(+), 33 deletions(-) diff --git a/openvino_xai/methods/base.py b/openvino_xai/methods/base.py index 9fdf0c55..cb385ac0 100644 --- a/openvino_xai/methods/base.py +++ b/openvino_xai/methods/base.py @@ -64,12 +64,6 @@ def generate_saliency_map(self, data: np.ndarray) -> Dict[int, np.ndarray] | np. """Saliency map generation.""" -class OVMethod(MethodBase[ov.Model, ov.CompiledModel]): - def load_model(self) -> None: - core = ov.Core() - self._model_compiled = core.compile_model(model=self._model, device_name=self._device_name) - - @dataclass class Prediction: label: int | None = None diff --git a/openvino_xai/methods/black_box/base.py b/openvino_xai/methods/black_box/base.py index 541c27b2..9a311b09 100644 --- a/openvino_xai/methods/black_box/base.py +++ b/openvino_xai/methods/black_box/base.py @@ -8,11 +8,11 @@ import openvino.runtime as ov from openvino_xai.common.utils import IdentityPreprocessFN -from openvino_xai.methods.base import OVMethod +from openvino_xai.methods.base import MethodBase from openvino_xai.methods.black_box.utils import check_classification_output -class BlackBoxXAIMethod(OVMethod): +class BlackBoxXAIMethod(MethodBase[ov.Model, ov.CompiledModel]): """Base class for methods that explain model in Black-Box mode.""" def __init__( @@ -28,7 +28,7 @@ def __init__( def prepare_model(self, load_model: bool = True) -> ov.Model: """Load model prior to inference.""" if load_model: - self.load_model() + self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name) return self._model def get_logits(self, data_preprocessed: np.ndarray) -> np.ndarray: diff --git a/openvino_xai/methods/white_box/activation_map.py b/openvino_xai/methods/white_box/activation_map.py index a5288188..a45fea60 100644 --- a/openvino_xai/methods/white_box/activation_map.py +++ b/openvino_xai/methods/white_box/activation_map.py @@ -39,7 +39,7 @@ def __new__( **kwargs, ): if isinstance(model, torch.nn.Module): - from .torch import ActivationMap as TorchActivationMap + from .torch import TorchActivationMap return TorchActivationMap(model, *args, **kwargs) return super().__new__(cls) diff --git a/openvino_xai/methods/white_box/base.py b/openvino_xai/methods/white_box/base.py index e7929b77..e7a8e10e 100644 --- a/openvino_xai/methods/white_box/base.py +++ b/openvino_xai/methods/white_box/base.py @@ -16,10 +16,10 @@ has_xai, ) from openvino_xai.inserter.inserter import insert_xai_branch_into_model -from openvino_xai.methods.base import OVMethod +from openvino_xai.methods.base import MethodBase -class WhiteBoxMethod(OVMethod): +class WhiteBoxMethod(MethodBase[ov.Model, ov.CompiledModel]): """ Base class for white-box XAI methods. @@ -64,7 +64,7 @@ def prepare_model(self, load_model: bool = True) -> ov.Model: logger.info("Provided IR model already contains XAI branch.") self._model = self._model_ori if load_model: - self.load_model() + self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name) return self._model xai_output_node = self.generate_xai_branch() @@ -72,7 +72,7 @@ def prepare_model(self, load_model: bool = True) -> ov.Model: if not has_xai(self._model): raise RuntimeError("Insertion of the XAI branch into the model was not successful.") if load_model: - self.load_model() + self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name) return self._model @staticmethod diff --git a/openvino_xai/methods/white_box/recipro_cam.py b/openvino_xai/methods/white_box/recipro_cam.py index 413dccde..e5ae6f0b 100644 --- a/openvino_xai/methods/white_box/recipro_cam.py +++ b/openvino_xai/methods/white_box/recipro_cam.py @@ -86,7 +86,7 @@ def __new__( **kwargs, ): if isinstance(model, torch.nn.Module): - from .torch import ReciproCAM as TorchReciproCAM + from .torch import TorchReciproCAM return TorchReciproCAM(model, *args, **kwargs) return super().__new__(cls) @@ -191,7 +191,7 @@ def __new__( **kwargs, ): if isinstance(model, torch.nn.Module): - from .torch import ViTReciproCAM as TorchViTReciproCAM + from .torch import TorchViTReciproCAM return TorchViTReciproCAM(model, *args, **kwargs) return super().__new__(cls) diff --git a/openvino_xai/methods/white_box/torch.py b/openvino_xai/methods/white_box/torch.py index 01c19262..a1e8e80f 100644 --- a/openvino_xai/methods/white_box/torch.py +++ b/openvino_xai/methods/white_box/torch.py @@ -10,11 +10,11 @@ import numpy as np import torch -from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME +from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME, has_xai from openvino_xai.methods.base import IdentityPreprocessFN, MethodBase -class TorchMethod(MethodBase[torch.nn.Module, torch.nn.Module]): +class TorchWhiteBoxMethod(MethodBase[torch.nn.Module, torch.nn.Module]): """ Base class for Torch-based methods. @@ -49,6 +49,11 @@ def __init__( def prepare_model(self, load_model: bool = True) -> torch.nn.Module: """Return XAI inserted model.""" + if has_xai(self._model): + if load_model: + self._model_compiled = self._model + return self._model + model = copy.deepcopy(self._model) # Feature feature_layer = model.get_submodule(self._target_layer) @@ -57,7 +62,9 @@ def prepare_model(self, load_model: bool = True) -> torch.nn.Module: model.register_forward_hook(self._output_hook) setattr(model, "has_xai", True) model.eval() - self._model_compiled = model + + if load_model: + self._model_compiled = model return model def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping: @@ -103,7 +110,7 @@ def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor: return saliency_map.to(torch.uint8) -class ActivationMap(TorchMethod): +class TorchActivationMap(TorchWhiteBoxMethod): """ActivationMap. Mean of the feature map along the channel dimension.""" def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]: @@ -120,7 +127,7 @@ def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tenso } -class ReciproCAM(TorchMethod): +class TorchReciproCAM(TorchWhiteBoxMethod): """Implementation of Recipro-CAM for class-wise saliency map. Recipro-CAM: gradient-free reciprocal class activation map (https://arxiv.org/pdf/2209.14074.pdf) @@ -181,7 +188,7 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w: return mosaic_feature_map -class ViTReciproCAM(ReciproCAM): +class TorchViTReciproCAM(TorchReciproCAM): """Implementation of ViTRecipro-CAM for class-wise saliency map for transformer-based classifiers. ViT-ReciproCAM: Gradient and Attention-Free Visual Explanations for Vision Transformer diff --git a/tests/unit/methods/white_box/test_torch.py b/tests/unit/methods/white_box/test_torch.py index 8032768d..fa374f6c 100644 --- a/tests/unit/methods/white_box/test_torch.py +++ b/tests/unit/methods/white_box/test_torch.py @@ -11,21 +11,21 @@ from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME, has_xai from openvino_xai.methods.white_box.torch import ( - ActivationMap, - ReciproCAM, - TorchMethod, - ViTReciproCAM, + TorchActivationMap, + TorchReciproCAM, + TorchViTReciproCAM, + TorchWhiteBoxMethod, ) def test_normalize(): x = torch.rand((2, 2)) * 100 - y = TorchMethod._normalize_map(x) + y = TorchWhiteBoxMethod._normalize_map(x) assert x.shape == y.shape assert torch.all(y >= 0) assert torch.all(y <= 255) x = torch.rand((2, 2, 2)) * 100 - y = TorchMethod._normalize_map(x) + y = TorchWhiteBoxMethod._normalize_map(x) assert x.shape == y.shape assert torch.all(y >= 0) assert torch.all(y <= 255) @@ -66,7 +66,7 @@ def forward(self, x: torch.Tensor): def test_torch_method(): model = DummyCNN() - method = TorchMethod(model=model, target_layer="feature") + method = TorchWhiteBoxMethod(model=model, target_layer="feature") model_xai = method.prepare_model() assert has_xai(model_xai) data = np.zeros((1, 3, 5, 5)) @@ -74,7 +74,7 @@ def test_torch_method(): assert type(output) == dict assert SALIENCY_MAP_OUTPUT_NAME in output - class DummyMethod(TorchMethod): + class DummyMethod(TorchWhiteBoxMethod): def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: output = torch.cat((output, output), dim=0) return super()._feature_hook(module, inputs, output) @@ -99,11 +99,28 @@ def _output_hook( assert np.all(saliency_maps == prediction) +def test_prepare_model(): + model = DummyCNN() + method = TorchWhiteBoxMethod(model=model, target_layer="feature") + model_xai = method.prepare_model(load_model=False) + assert method._model_compiled is None + assert model is not model_xai + + model_xai = method.prepare_model(load_model=True) + assert method._model_compiled is not None + assert model is not model_xai + + model.has_xai = True + method = TorchWhiteBoxMethod(model=model, target_layer="feature") + model_xai = method.prepare_model(load_model=False) + assert model_xai == model + + def test_activationmap() -> None: batch_size = 2 num_classes = 3 model = DummyCNN(num_classes=num_classes) - method = ActivationMap(model=model, target_layer="feature") + method = TorchActivationMap(model=model, target_layer="feature") model_xai = method.prepare_model() assert has_xai(model_xai) data = np.random.rand(batch_size, 3, 5, 5) @@ -121,7 +138,7 @@ def test_reciprocam(optimize_gap: bool) -> None: batch_size = 2 num_classes = 3 model = DummyCNN(num_classes=num_classes) - method = ReciproCAM(model=model, target_layer="feature", optimize_gap=optimize_gap) + method = TorchReciproCAM(model=model, target_layer="feature", optimize_gap=optimize_gap) model_xai = method.prepare_model() assert has_xai(model_xai) data = np.random.rand(batch_size, 4, 5, 5) @@ -140,7 +157,9 @@ def test_vitreciprocam(use_gaussian: bool, use_cls_token: bool) -> None: batch_size = 2 num_classes = 3 model = DummyVIT(num_classes=num_classes) - method = ViTReciproCAM(model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token) + method = TorchViTReciproCAM( + model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token + ) model_xai = method.prepare_model() assert has_xai(model_xai) data = np.random.rand(batch_size, 4, 5, 5)