Skip to content

Commit

Permalink
Enable torch method creation via factory
Browse files Browse the repository at this point in the history
  • Loading branch information
goodsong81 committed Sep 4, 2024
1 parent 9bfed70 commit b1fc71e
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 14 deletions.
11 changes: 4 additions & 7 deletions openvino_xai/methods/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,6 @@ def create_detection_method(
:type embed_scaling: bool
"""

if isinstance(model, torch.nn.Module):
raise NotImplementedError("Torch models are not yet supported by detection white box methods.")

if target_layer is None:
raise ValueError("target_layer is required for the detection.")

Expand All @@ -194,7 +191,7 @@ class BlackBoxMethodFactory(MethodFactory):
def create_method(
cls,
task: Task,
model: ov.Model,
model: ov.Model | torch.nn.Module,
postprocess_fn: Callable[[Mapping], np.ndarray],
preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
explain_method: Method | None = None,
Expand All @@ -213,7 +210,7 @@ def create_method(

@staticmethod
def create_classification_method(
model: ov.Model,
model: ov.Model | torch.nn.Module,
postprocess_fn: Callable[[Mapping], np.ndarray],
preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
explain_method: Method | None = None,
Expand All @@ -224,7 +221,7 @@ def create_classification_method(
Using AISE as a default method.
:param model: Input model.
:type model: ov.Model
:type model: ov.Model | torch.nn.Module
:param postprocess_fn: Preprocessing function that extract scores from model output.
:type postprocess_fn: Callable[[Mapping], np.ndarray]
:param preprocess_fn: Preprocessing function, identity function by default
Expand All @@ -241,7 +238,7 @@ def create_classification_method(

@staticmethod
def create_detection_method(
model: ov.Model,
model: ov.Model | torch.nn.Module,
postprocess_fn: Callable[[Mapping], np.ndarray],
preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
explain_method: Method | None = None,
Expand Down
13 changes: 13 additions & 0 deletions openvino_xai/methods/white_box/activation_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import openvino.runtime as ov
import torch
from openvino.runtime import opset10 as opset

from openvino_xai.common.utils import IdentityPreprocessFN
Expand All @@ -31,6 +32,18 @@ class ActivationMap(WhiteBoxMethod):
:type prepare_model: bool
"""

def __new__(
cls,
model: ov.Model | torch.nn.Module | None = None,
*args,
**kwargs,
):
if isinstance(model, torch.nn.Module):
from .torch import ActivationMap as TorchActivationMap

return TorchActivationMap(model, *args, **kwargs)
return super().__new__(cls)

def __init__(
self,
model: ov.Model,
Expand Down
25 changes: 25 additions & 0 deletions openvino_xai/methods/white_box/recipro_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import openvino.runtime as ov
import torch
from openvino.runtime import opset10 as opset

from openvino_xai.common.utils import IdentityPreprocessFN
Expand Down Expand Up @@ -78,6 +79,18 @@ class ReciproCAM(FeatureMapPerturbationBase):
:type prepare_model: bool
"""

def __new__(
cls,
model: ov.Model | torch.nn.Module | None = None,
*args,
**kwargs,
):
if isinstance(model, torch.nn.Module):
from .torch import ReciproCAM as TorchReciproCAM

return TorchReciproCAM(model, *args, **kwargs)
return super().__new__(cls)

def __init__(
self,
model: ov.Model,
Expand Down Expand Up @@ -171,6 +184,18 @@ class ViTReciproCAM(FeatureMapPerturbationBase):
:type prepare_model: bool
"""

def __new__(
cls,
model: ov.Model | torch.nn.Module | None = None,
*args,
**kwargs,
):
if isinstance(model, torch.nn.Module):
from .torch import ViTReciproCAM as TorchViTReciproCAM

return TorchViTReciproCAM(model, *args, **kwargs)
return super().__new__(cls)

def __init__(
self,
model: ov.Model,
Expand Down
7 changes: 4 additions & 3 deletions openvino_xai/methods/white_box/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
target_layer: str | None = None,
embed_scaling: bool = True,
device_name: str = "CPU",
**kwargs,
):
super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name)
self._target_layer = target_layer
Expand Down Expand Up @@ -224,8 +225,8 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w:
if self._use_gaussian:
if self._use_cls_token:
mosaic_feature_map[:, 0, :] = feature_map[0, :]
feature_map_spacial = feature_map[1:, :].reshape(1, h, w, c)
feature_map_spacial_repeated = feature_map_spacial.repeat(h * w, 1, 1, 1) # 196, 14, 14, 192
feature_map_spatial = feature_map[1:, :].reshape(1, h, w, c)
feature_map_spatial_repeated = feature_map_spatial.repeat(h * w, 1, 1, 1) # 196, 14, 14, 192

spatial_order = torch.arange(h * w).reshape(h, w)
gaussian = torch.tensor(
Expand All @@ -241,7 +242,7 @@ def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w:
mosaic_feature_map_mask = mosaic_feature_map_mask_padded[:, 1:-1, 1:-1]
mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(3).repeat(1, 1, 1, c)

mosaic_fm_wo_cls_token = feature_map_spacial_repeated * mosaic_feature_map_mask
mosaic_fm_wo_cls_token = feature_map_spatial_repeated * mosaic_feature_map_mask
mosaic_feature_map[:, 1:, :] = mosaic_fm_wo_cls_token.reshape(h * w, h * w, c)
else:
feature_map_repeated = feature_map.unsqueeze(0).repeat(h * w, 1, 1)
Expand Down
30 changes: 26 additions & 4 deletions tests/unit/methods/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn
from openvino_xai.methods.black_box.aise.classification import AISEClassification
from openvino_xai.methods.factory import BlackBoxMethodFactory, WhiteBoxMethodFactory
from openvino_xai.methods.white_box import torch as torch_method
from openvino_xai.methods.white_box.activation_map import ActivationMap
from openvino_xai.methods.white_box.det_class_probability_map import (
DetClassProbabilityMap,
Expand Down Expand Up @@ -151,15 +152,36 @@ def test_create_wb_det_cnn_method(fxt_data_root: Path):


def test_create_torch_method():
model = torch.nn.Module()
with pytest.raises(NotImplementedError):
explain_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn())
model = {}
with pytest.raises(ValueError):
explain_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn())
model = torch.nn.Module()
with pytest.raises(NotImplementedError):
explain_method = WhiteBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn())
explain_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn())
with pytest.raises(NotImplementedError):
explain_method = BlackBoxMethodFactory.create_method(
Task.DETECTION, model, get_postprocess_fn(), target_layer=""
)

model = {}
with pytest.raises(ValueError):
explain_method = WhiteBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn())
model = torch.nn.Module()
with pytest.raises(NotImplementedError):
explain_method = WhiteBoxMethodFactory.create_method(
Task.DETECTION, model, get_postprocess_fn(), target_layer=""
)

model = torch.nn.Module()
explain_method = WhiteBoxMethodFactory.create_method(
Task.CLASSIFICATION, model, get_postprocess_fn(), explain_method=Method.ACTIVATIONMAP
)
assert isinstance(explain_method, torch_method.ActivationMap)
explain_method = WhiteBoxMethodFactory.create_method(
Task.CLASSIFICATION, model, get_postprocess_fn(), explain_method=Method.RECIPROCAM
)
assert isinstance(explain_method, torch_method.ReciproCAM)
explain_method = WhiteBoxMethodFactory.create_method(
Task.CLASSIFICATION, model, get_postprocess_fn(), explain_method=Method.VITRECIPROCAM
)
assert isinstance(explain_method, torch_method.ViTReciproCAM)

0 comments on commit b1fc71e

Please sign in to comment.