Skip to content

Commit

Permalink
Refactor torch method class hierarchy
Browse files Browse the repository at this point in the history
  • Loading branch information
goodsong81 committed Sep 5, 2024
1 parent d93223a commit 3fa9b04
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 33 deletions.
6 changes: 0 additions & 6 deletions openvino_xai/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,6 @@ def generate_saliency_map(self, data: np.ndarray) -> Dict[int, np.ndarray] | np.
"""Saliency map generation."""


class OVMethod(MethodBase[ov.Model, ov.CompiledModel]):
def load_model(self) -> None:
core = ov.Core()
self._model_compiled = core.compile_model(model=self._model, device_name=self._device_name)


@dataclass
class Prediction:
label: int | None = None
Expand Down
6 changes: 3 additions & 3 deletions openvino_xai/methods/black_box/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import openvino.runtime as ov

from openvino_xai.common.utils import IdentityPreprocessFN
from openvino_xai.methods.base import OVMethod
from openvino_xai.methods.base import MethodBase
from openvino_xai.methods.black_box.utils import check_classification_output


class BlackBoxXAIMethod(OVMethod):
class BlackBoxXAIMethod(MethodBase[ov.Model, ov.CompiledModel]):
"""Base class for methods that explain model in Black-Box mode."""

def __init__(
Expand All @@ -28,7 +28,7 @@ def __init__(
def prepare_model(self, load_model: bool = True) -> ov.Model:
"""Load model prior to inference."""
if load_model:
self.load_model()
self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name)
return self._model

def get_logits(self, data_preprocessed: np.ndarray) -> np.ndarray:
Expand Down
2 changes: 1 addition & 1 deletion openvino_xai/methods/white_box/activation_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __new__(
**kwargs,
):
if isinstance(model, torch.nn.Module):
from .torch import ActivationMap as TorchActivationMap
from .torch import TorchActivationMap

return TorchActivationMap(model, *args, **kwargs)
return super().__new__(cls)
Expand Down
8 changes: 4 additions & 4 deletions openvino_xai/methods/white_box/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
has_xai,
)
from openvino_xai.inserter.inserter import insert_xai_branch_into_model
from openvino_xai.methods.base import OVMethod
from openvino_xai.methods.base import MethodBase


class WhiteBoxMethod(OVMethod):
class WhiteBoxMethod(MethodBase[ov.Model, ov.CompiledModel]):
"""
Base class for white-box XAI methods.
Expand Down Expand Up @@ -64,15 +64,15 @@ def prepare_model(self, load_model: bool = True) -> ov.Model:
logger.info("Provided IR model already contains XAI branch.")
self._model = self._model_ori
if load_model:
self.load_model()
self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name)
return self._model

xai_output_node = self.generate_xai_branch()
self._model = insert_xai_branch_into_model(self._model_ori, xai_output_node, self.embed_scaling)
if not has_xai(self._model):
raise RuntimeError("Insertion of the XAI branch into the model was not successful.")
if load_model:
self.load_model()
self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name)
return self._model

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions openvino_xai/methods/white_box/recipro_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __new__(
**kwargs,
):
if isinstance(model, torch.nn.Module):
from .torch import ReciproCAM as TorchReciproCAM
from .torch import TorchReciproCAM

return TorchReciproCAM(model, *args, **kwargs)
return super().__new__(cls)
Expand Down Expand Up @@ -191,7 +191,7 @@ def __new__(
**kwargs,
):
if isinstance(model, torch.nn.Module):
from .torch import ViTReciproCAM as TorchViTReciproCAM
from .torch import TorchViTReciproCAM

return TorchViTReciproCAM(model, *args, **kwargs)
return super().__new__(cls)
Expand Down
19 changes: 13 additions & 6 deletions openvino_xai/methods/white_box/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import numpy as np
import torch

from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME
from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME, has_xai
from openvino_xai.methods.base import IdentityPreprocessFN, MethodBase


class TorchMethod(MethodBase[torch.nn.Module, torch.nn.Module]):
class TorchWhiteBoxMethod(MethodBase[torch.nn.Module, torch.nn.Module]):
"""
Base class for Torch-based methods.
Expand Down Expand Up @@ -49,6 +49,11 @@ def __init__(

def prepare_model(self, load_model: bool = True) -> torch.nn.Module:
"""Return XAI inserted model."""
if has_xai(self._model):
if load_model:
self._model_compiled = self._model
return self._model

model = copy.deepcopy(self._model)
# Feature
feature_layer = model.get_submodule(self._target_layer)
Expand All @@ -57,7 +62,9 @@ def prepare_model(self, load_model: bool = True) -> torch.nn.Module:
model.register_forward_hook(self._output_hook)
setattr(model, "has_xai", True)
model.eval()
self._model_compiled = model

if load_model:
self._model_compiled = model
return model

def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping:
Expand Down Expand Up @@ -103,7 +110,7 @@ def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor:
return saliency_map.to(torch.uint8)


class ActivationMap(TorchMethod):
class TorchActivationMap(TorchWhiteBoxMethod):
"""ActivationMap. Mean of the feature map along the channel dimension."""

def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
Expand All @@ -120,7 +127,7 @@ def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tenso
}


class ReciproCAM(TorchMethod):
class TorchReciproCAM(TorchWhiteBoxMethod):
"""Implementation of Recipro-CAM for class-wise saliency map.
Recipro-CAM: gradient-free reciprocal class activation map (https://arxiv.org/pdf/2209.14074.pdf)
Expand Down Expand Up @@ -181,7 +188,7 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w:
return mosaic_feature_map


class ViTReciproCAM(ReciproCAM):
class TorchViTReciproCAM(TorchReciproCAM):
"""Implementation of ViTRecipro-CAM for class-wise saliency map for transformer-based classifiers.
ViT-ReciproCAM: Gradient and Attention-Free Visual Explanations for Vision Transformer
Expand Down
41 changes: 30 additions & 11 deletions tests/unit/methods/white_box/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@

from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME, has_xai
from openvino_xai.methods.white_box.torch import (
ActivationMap,
ReciproCAM,
TorchMethod,
ViTReciproCAM,
TorchActivationMap,
TorchReciproCAM,
TorchViTReciproCAM,
TorchWhiteBoxMethod,
)


def test_normalize():
x = torch.rand((2, 2)) * 100
y = TorchMethod._normalize_map(x)
y = TorchWhiteBoxMethod._normalize_map(x)
assert x.shape == y.shape
assert torch.all(y >= 0)
assert torch.all(y <= 255)
x = torch.rand((2, 2, 2)) * 100
y = TorchMethod._normalize_map(x)
y = TorchWhiteBoxMethod._normalize_map(x)
assert x.shape == y.shape
assert torch.all(y >= 0)
assert torch.all(y <= 255)
Expand Down Expand Up @@ -66,15 +66,15 @@ def forward(self, x: torch.Tensor):

def test_torch_method():
model = DummyCNN()
method = TorchMethod(model=model, target_layer="feature")
method = TorchWhiteBoxMethod(model=model, target_layer="feature")
model_xai = method.prepare_model()
assert has_xai(model_xai)
data = np.zeros((1, 3, 5, 5))
output = method.model_forward(data)
assert type(output) == dict
assert SALIENCY_MAP_OUTPUT_NAME in output

class DummyMethod(TorchMethod):
class DummyMethod(TorchWhiteBoxMethod):
def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
output = torch.cat((output, output), dim=0)
return super()._feature_hook(module, inputs, output)
Expand All @@ -99,11 +99,28 @@ def _output_hook(
assert np.all(saliency_maps == prediction)


def test_prepare_model():
model = DummyCNN()
method = TorchWhiteBoxMethod(model=model, target_layer="feature")
model_xai = method.prepare_model(load_model=False)
assert method._model_compiled is None
assert model is not model_xai

model_xai = method.prepare_model(load_model=True)
assert method._model_compiled is not None
assert model is not model_xai

model.has_xai = True
method = TorchWhiteBoxMethod(model=model, target_layer="feature")
model_xai = method.prepare_model(load_model=False)
assert model_xai == model


def test_activationmap() -> None:
batch_size = 2
num_classes = 3
model = DummyCNN(num_classes=num_classes)
method = ActivationMap(model=model, target_layer="feature")
method = TorchActivationMap(model=model, target_layer="feature")
model_xai = method.prepare_model()
assert has_xai(model_xai)
data = np.random.rand(batch_size, 3, 5, 5)
Expand All @@ -121,7 +138,7 @@ def test_reciprocam(optimize_gap: bool) -> None:
batch_size = 2
num_classes = 3
model = DummyCNN(num_classes=num_classes)
method = ReciproCAM(model=model, target_layer="feature", optimize_gap=optimize_gap)
method = TorchReciproCAM(model=model, target_layer="feature", optimize_gap=optimize_gap)
model_xai = method.prepare_model()
assert has_xai(model_xai)
data = np.random.rand(batch_size, 4, 5, 5)
Expand All @@ -140,7 +157,9 @@ def test_vitreciprocam(use_gaussian: bool, use_cls_token: bool) -> None:
batch_size = 2
num_classes = 3
model = DummyVIT(num_classes=num_classes)
method = ViTReciproCAM(model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token)
method = TorchViTReciproCAM(
model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token
)
model_xai = method.prepare_model()
assert has_xai(model_xai)
data = np.random.rand(batch_size, 4, 5, 5)
Expand Down

0 comments on commit 3fa9b04

Please sign in to comment.