Skip to content

Commit

Permalink
Refactor torch method class hierarchy
Browse files Browse the repository at this point in the history
  • Loading branch information
goodsong81 committed Sep 5, 2024
1 parent d93223a commit ce43d6b
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 26 deletions.
6 changes: 0 additions & 6 deletions openvino_xai/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,6 @@ def generate_saliency_map(self, data: np.ndarray) -> Dict[int, np.ndarray] | np.
"""Saliency map generation."""


class OVMethod(MethodBase[ov.Model, ov.CompiledModel]):
def load_model(self) -> None:
core = ov.Core()
self._model_compiled = core.compile_model(model=self._model, device_name=self._device_name)


@dataclass
class Prediction:
label: int | None = None
Expand Down
6 changes: 3 additions & 3 deletions openvino_xai/methods/black_box/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import openvino.runtime as ov

from openvino_xai.common.utils import IdentityPreprocessFN
from openvino_xai.methods.base import OVMethod
from openvino_xai.methods.base import MethodBase
from openvino_xai.methods.black_box.utils import check_classification_output


class BlackBoxXAIMethod(OVMethod):
class BlackBoxXAIMethod(MethodBase[ov.Model, ov.CompiledModel]):
"""Base class for methods that explain model in Black-Box mode."""

def __init__(
Expand All @@ -28,7 +28,7 @@ def __init__(
def prepare_model(self, load_model: bool = True) -> ov.Model:
"""Load model prior to inference."""
if load_model:
self.load_model()
self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name)
return self._model

def get_logits(self, data_preprocessed: np.ndarray) -> np.ndarray:
Expand Down
2 changes: 1 addition & 1 deletion openvino_xai/methods/white_box/activation_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __new__(
**kwargs,
):
if isinstance(model, torch.nn.Module):
from .torch import ActivationMap as TorchActivationMap
from .torch import TorchActivationMap

return TorchActivationMap(model, *args, **kwargs)
return super().__new__(cls)
Expand Down
8 changes: 4 additions & 4 deletions openvino_xai/methods/white_box/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
has_xai,
)
from openvino_xai.inserter.inserter import insert_xai_branch_into_model
from openvino_xai.methods.base import OVMethod
from openvino_xai.methods.base import MethodBase


class WhiteBoxMethod(OVMethod):
class WhiteBoxMethod(MethodBase[ov.Model, ov.CompiledModel]):
"""
Base class for white-box XAI methods.
Expand Down Expand Up @@ -64,15 +64,15 @@ def prepare_model(self, load_model: bool = True) -> ov.Model:
logger.info("Provided IR model already contains XAI branch.")
self._model = self._model_ori
if load_model:
self.load_model()
self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name)
return self._model

xai_output_node = self.generate_xai_branch()
self._model = insert_xai_branch_into_model(self._model_ori, xai_output_node, self.embed_scaling)
if not has_xai(self._model):
raise RuntimeError("Insertion of the XAI branch into the model was not successful.")
if load_model:
self.load_model()
self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name)
return self._model

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions openvino_xai/methods/white_box/recipro_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __new__(
**kwargs,
):
if isinstance(model, torch.nn.Module):
from .torch import ReciproCAM as TorchReciproCAM
from .torch import TorchReciproCAM

return TorchReciproCAM(model, *args, **kwargs)
return super().__new__(cls)
Expand Down Expand Up @@ -191,7 +191,7 @@ def __new__(
**kwargs,
):
if isinstance(model, torch.nn.Module):
from .torch import ViTReciproCAM as TorchViTReciproCAM
from .torch import TorchViTReciproCAM

return TorchViTReciproCAM(model, *args, **kwargs)
return super().__new__(cls)
Expand Down
8 changes: 4 additions & 4 deletions openvino_xai/methods/white_box/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from openvino_xai.methods.base import IdentityPreprocessFN, MethodBase


class TorchMethod(MethodBase[torch.nn.Module, torch.nn.Module]):
class TorchWhiteBoxMethod(MethodBase[torch.nn.Module, torch.nn.Module]):
"""
Base class for Torch-based methods.
Expand Down Expand Up @@ -103,7 +103,7 @@ def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor:
return saliency_map.to(torch.uint8)


class ActivationMap(TorchMethod):
class TorchActivationMap(TorchWhiteBoxMethod):
"""ActivationMap. Mean of the feature map along the channel dimension."""

def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
Expand All @@ -120,7 +120,7 @@ def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tenso
}


class ReciproCAM(TorchMethod):
class TorchReciproCAM(TorchWhiteBoxMethod):
"""Implementation of Recipro-CAM for class-wise saliency map.
Recipro-CAM: gradient-free reciprocal class activation map (https://arxiv.org/pdf/2209.14074.pdf)
Expand Down Expand Up @@ -181,7 +181,7 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w:
return mosaic_feature_map


class ViTReciproCAM(ReciproCAM):
class TorchViTReciproCAM(TorchReciproCAM):
"""Implementation of ViTRecipro-CAM for class-wise saliency map for transformer-based classifiers.
ViT-ReciproCAM: Gradient and Attention-Free Visual Explanations for Vision Transformer
Expand Down
14 changes: 8 additions & 6 deletions tests/unit/methods/white_box/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME, has_xai
from openvino_xai.methods.white_box.torch import (
ActivationMap,
ReciproCAM,
TorchActivationMap,
TorchMethod,
ViTReciproCAM,
TorchReciproCAM,
TorchViTReciproCAM,
)


Expand Down Expand Up @@ -103,7 +103,7 @@ def test_activationmap() -> None:
batch_size = 2
num_classes = 3
model = DummyCNN(num_classes=num_classes)
method = ActivationMap(model=model, target_layer="feature")
method = TorchActivationMap(model=model, target_layer="feature")
model_xai = method.prepare_model()
assert has_xai(model_xai)
data = np.random.rand(batch_size, 3, 5, 5)
Expand All @@ -121,7 +121,7 @@ def test_reciprocam(optimize_gap: bool) -> None:
batch_size = 2
num_classes = 3
model = DummyCNN(num_classes=num_classes)
method = ReciproCAM(model=model, target_layer="feature", optimize_gap=optimize_gap)
method = TorchReciproCAM(model=model, target_layer="feature", optimize_gap=optimize_gap)
model_xai = method.prepare_model()
assert has_xai(model_xai)
data = np.random.rand(batch_size, 4, 5, 5)
Expand All @@ -140,7 +140,9 @@ def test_vitreciprocam(use_gaussian: bool, use_cls_token: bool) -> None:
batch_size = 2
num_classes = 3
model = DummyVIT(num_classes=num_classes)
method = ViTReciproCAM(model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token)
method = TorchViTReciproCAM(
model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token
)
model_xai = method.prepare_model()
assert has_xai(model_xai)
data = np.random.rand(batch_size, 4, 5, 5)
Expand Down

0 comments on commit ce43d6b

Please sign in to comment.