Skip to content

Commit

Permalink
Support Pytorch models for insert_xai API (#61)
Browse files Browse the repository at this point in the history
* Update build dependency with torch support

* Apply generic interface for torch model support

* Extend has_xai() for torch

* Align unit test structure w/ src code

* Add default __new__() for MethodBase

* Copy torch XAI algos from OTX

* Implement inserter for torch models

* Implement ActivationMap torch hooks

* Implement ReciproCAM torch hooks

* Fix shape mismatch w/ optimize_gap=True

* Implement ViTReciproCAM torch hooks

* Enable torch method creation via factory

* Add integration test

* Refactor torch method class hierarchy
  • Loading branch information
goodsong81 authored Sep 5, 2024
1 parent f637a5f commit 8b0ddf9
Show file tree
Hide file tree
Showing 28 changed files with 665 additions and 64 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ jobs:
- name: Install tox
run: python -m pip install tox==4.4.6
- name: Run Functional Test
run: tox -vv -e val-py310 -- -v tests/func --csv=.tox/val-py310/func-test.csv -n 1 --max-worker-restart 100 --clear-cache
run: tox -vv -e dev-py310 -- -v tests/func --csv=.tox/dev-py310/func-test.csv -n 1 --max-worker-restart 100 --clear-cache
- name: Upload artifacts
uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1
with:
name: func-test-results
path: .tox/val-py310/*.csv
path: .tox/dev-py310/*.csv
# Use always() to always run this step to publish test results when there are test failures
if: ${{ always() }}
2 changes: 1 addition & 1 deletion .github/workflows/pre_merge.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ jobs:
- name: Install tox
run: python -m pip install tox==4.4.6
- name: Run Integration Test
run: tox -vv -e val-py310 -- tests/intg --csv=.tox/dev-py310/intg-test.csv -n 1 --clear-cache
run: tox -vv -e dev-py310 -- tests/intg --csv=.tox/dev-py310/intg-test.csv -n 1 --clear-cache
- name: Upload artifacts
uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1
with:
Expand Down
17 changes: 10 additions & 7 deletions openvino_xai/api/api.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import List
from typing import List, TypeVar

import openvino as ov
import torch

from openvino_xai.common.parameters import Method, Task
from openvino_xai.common.utils import IdentityPreprocessFN, has_xai, logger
from openvino_xai.methods.factory import WhiteBoxMethodFactory

Model = TypeVar("Model", ov.Model, torch.nn.Module)


def insert_xai(
model: ov.Model,
model: Model,
task: Task,
explain_method: Method | None = None,
target_layer: str | List[str] | None = None,
embed_scaling: bool | None = True,
**kwargs,
) -> ov.Model:
) -> Model:
"""
Function that inserts XAI branch into IR.
Inserts XAI branch into the given model.
Usage:
model_xai = openvino_xai.insert_xai(model, task=Task.CLASSIFICATION)
:param model: Original IR.
:type model: ov.Model | str
:param model: Original model.
:type model: ov.Model | torch.nn.Module
:param task: Type of the task: CLASSIFICATION or DETECTION.
:type task: Task
:parameter explain_method: Explain method to use for model explanation.
Expand All @@ -37,7 +40,7 @@ def insert_xai(
"""

if has_xai(model):
logger.info("Provided IR model already contains XAI branch, return it as-is.")
logger.info("Provided model already contains XAI branch, return it as-is.")
return model

method = WhiteBoxMethodFactory.create_method(
Expand Down
24 changes: 14 additions & 10 deletions openvino_xai/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from urllib.request import urlretrieve

import numpy as np
import openvino.runtime as ov
import openvino as ov
import torch

logger = logging.getLogger("openvino_xai")
handler = logging.StreamHandler()
Expand All @@ -23,20 +24,23 @@
SALIENCY_MAP_OUTPUT_NAME = "saliency_map"


def has_xai(model: ov.Model) -> bool:
def has_xai(model: ov.Model | torch.nn.Module) -> bool:
"""
Function checks if the model contains XAI branch.
:param model: OV IR model.
:type model: ov.Model
:param model: Input model for inspect.
:type model: ov.Model | torch.nn.Module
:return: True is the model has XAI branch and saliency_map output, False otherwise.
"""
if not isinstance(model, ov.Model):
raise ValueError(f"Input model has to be ov.Model instance, but got{type(model)}.")
for output in model.outputs:
if SALIENCY_MAP_OUTPUT_NAME in output.get_names():
return True
return False
if isinstance(model, ov.Model):
for output in model.outputs:
if SALIENCY_MAP_OUTPUT_NAME in output.get_names():
return True
return False
elif isinstance(model, torch.nn.Module):
return getattr(model, "has_xai", False)
else:
raise ValueError(f"Input model has to be openvino.Model or torch.nn.Module instance, but got{type(model)}.")


# Not a part of product
Expand Down
9 changes: 5 additions & 4 deletions openvino_xai/inserter/inserter.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import openvino.runtime as ov

import openvino as ov
from openvino.preprocess import PrePostProcessor

from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME


def insert_xai_branch_into_model(
model: ov.Model,
xai_output_node,
set_uint8,
xai_output_node: ov.runtime.Node,
set_uint8: bool,
) -> ov.Model:
"""Creates new model with XAI branch."""
"""Create new model with XAI branch."""
model_ori_outputs = model.outputs
model_ori_params = model.get_parameters()
model_xai = ov.Model([*model_ori_outputs, xai_output_node.output(0)], model_ori_params)
Expand Down
33 changes: 23 additions & 10 deletions openvino_xai/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,38 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, List, Mapping, Tuple
from typing import Callable, Dict, Generic, List, Mapping, Tuple, TypeAlias, TypeVar

import numpy as np
import openvino as ov
import torch

from openvino_xai.common.utils import IdentityPreprocessFN

Model = TypeVar("Model", ov.Model, torch.nn.Module)
CompiledModel = TypeVar("CompiledModel", ov.CompiledModel, torch.nn.Module)
PreprocessFn: TypeAlias = Callable[[np.ndarray], np.ndarray]

class MethodBase(ABC):

class MethodBase(ABC, Generic[Model, CompiledModel]):
"""Base class for XAI methods."""

def __new__(
cls,
model: Model | None = None,
*args,
**kwargs,
):
if isinstance(model, torch.nn.Module):
raise NotImplementedError(f"{type(model)} is not yet supported for {cls}")
elif model is not None and not isinstance(model, ov.Model):
raise ValueError(f"{type(model)} is not supported")
return super().__new__(cls)

def __init__(
self,
model: ov.Model = None,
preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
model: Model | None = None,
preprocess_fn: PreprocessFn = IdentityPreprocessFN(),
device_name: str = "CPU",
):
self._model = model
Expand All @@ -27,11 +44,11 @@ def __init__(
self.predictions: Dict[int, Prediction] = {}

@property
def model_compiled(self) -> ov.CompiledModel | None:
def model_compiled(self) -> CompiledModel | None:
return self._model_compiled

@abstractmethod
def prepare_model(self, load_model: bool = True) -> ov.Model:
def prepare_model(self, load_model: bool = True) -> Model:
"""Model preparation steps."""

def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping:
Expand All @@ -46,10 +63,6 @@ def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping:
def generate_saliency_map(self, data: np.ndarray) -> Dict[int, np.ndarray] | np.ndarray:
"""Saliency map generation."""

def load_model(self) -> None:
core = ov.Core()
self._model_compiled = core.compile_model(model=self._model, device_name=self._device_name)


@dataclass
class Prediction:
Expand Down
4 changes: 2 additions & 2 deletions openvino_xai/methods/black_box/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from openvino_xai.methods.black_box.utils import check_classification_output


class BlackBoxXAIMethod(MethodBase):
class BlackBoxXAIMethod(MethodBase[ov.Model, ov.CompiledModel]):
"""Base class for methods that explain model in Black-Box mode."""

def __init__(
Expand All @@ -28,7 +28,7 @@ def __init__(
def prepare_model(self, load_model: bool = True) -> ov.Model:
"""Load model prior to inference."""
if load_model:
self.load_model()
self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name)
return self._model

def get_logits(self, data_preprocessed: np.ndarray) -> np.ndarray:
Expand Down
31 changes: 16 additions & 15 deletions openvino_xai/methods/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from typing import Callable, List, Mapping

import numpy as np
import openvino.runtime as ov
import openvino as ov
import torch

from openvino_xai.common.parameters import Method, Task
from openvino_xai.common.utils import IdentityPreprocessFN, logger
Expand Down Expand Up @@ -44,7 +45,7 @@ class WhiteBoxMethodFactory(MethodFactory):
def create_method(
cls,
task: Task,
model: ov.Model,
model: ov.Model | torch.nn.Module,
preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
explain_method: Method | None = None,
target_layer: str | List[str] | None = None,
Expand Down Expand Up @@ -72,11 +73,11 @@ def create_method(
device_name,
**kwargs,
)
raise ValueError(f"Model type {task} is not supported in white-box mode.")
raise ValueError(f"Task type {task} is not supported in white-box mode.")

@staticmethod
def create_classification_method(
model: ov.Model,
model: ov.Model | torch.nn.Module,
preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
explain_method: Method | None = None,
target_layer: str | None = None,
Expand All @@ -86,8 +87,8 @@ def create_classification_method(
) -> WhiteBoxMethod:
"""Generates instance of the classification white-box method class.
:param model: OV IR model.
:type model: ov.Model
:param model: Input model.
:type model: ov.Model | torch.nn.Module
:param preprocess_fn: Preprocessing function, identity function by default
(assume input images are already preprocessed by user).
:type preprocess_fn: Callable[[np.ndarray], np.ndarray]
Expand Down Expand Up @@ -147,7 +148,7 @@ def create_classification_method(

@staticmethod
def create_detection_method(
model: ov.Model,
model: ov.Model | torch.nn.Module,
preprocess_fn: Callable[[np.ndarray], np.ndarray],
explain_method: Method | None = None,
target_layer: List[str] | None = None,
Expand All @@ -157,8 +158,8 @@ def create_detection_method(
) -> WhiteBoxMethod:
"""Generates instance of the detection white-box method class.
:param model: OV IR model.
:type model: ov.Model
:param model: Input model.
:type model: ov.Model | torch.nn.Module
:param preprocess_fn: Preprocessing function, identity function by default
(assume input images are already preprocessed by user).
:type preprocess_fn: Callable[[np.ndarray], np.ndarray]
Expand Down Expand Up @@ -190,7 +191,7 @@ class BlackBoxMethodFactory(MethodFactory):
def create_method(
cls,
task: Task,
model: ov.Model,
model: ov.Model | torch.nn.Module,
postprocess_fn: Callable[[Mapping], np.ndarray],
preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
explain_method: Method | None = None,
Expand All @@ -209,7 +210,7 @@ def create_method(

@staticmethod
def create_classification_method(
model: ov.Model,
model: ov.Model | torch.nn.Module,
postprocess_fn: Callable[[Mapping], np.ndarray],
preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
explain_method: Method | None = None,
Expand All @@ -219,9 +220,9 @@ def create_classification_method(
"""Generates instance of the classification black-box method class.
Using AISE as a default method.
:param model: OV IR model.
:type model: ov.Model
:param postprocess_fn: Preprocessing function that extract scores from IR model output.
:param model: Input model.
:type model: ov.Model | torch.nn.Module
:param postprocess_fn: Preprocessing function that extract scores from model output.
:type postprocess_fn: Callable[[Mapping], np.ndarray]
:param preprocess_fn: Preprocessing function, identity function by default
(assume input images are already preprocessed by user).
Expand All @@ -237,7 +238,7 @@ def create_classification_method(

@staticmethod
def create_detection_method(
model: ov.Model,
model: ov.Model | torch.nn.Module,
postprocess_fn: Callable[[Mapping], np.ndarray],
preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
explain_method: Method | None = None,
Expand Down
13 changes: 13 additions & 0 deletions openvino_xai/methods/white_box/activation_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import openvino.runtime as ov
import torch
from openvino.runtime import opset10 as opset

from openvino_xai.common.utils import IdentityPreprocessFN
Expand All @@ -31,6 +32,18 @@ class ActivationMap(WhiteBoxMethod):
:type prepare_model: bool
"""

def __new__(
cls,
model: ov.Model | torch.nn.Module | None = None,
*args,
**kwargs,
):
if isinstance(model, torch.nn.Module):
from .torch import TorchActivationMap

return TorchActivationMap(model, *args, **kwargs)
return super().__new__(cls)

def __init__(
self,
model: ov.Model,
Expand Down
6 changes: 3 additions & 3 deletions openvino_xai/methods/white_box/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from openvino_xai.methods.base import MethodBase


class WhiteBoxMethod(MethodBase):
class WhiteBoxMethod(MethodBase[ov.Model, ov.CompiledModel]):
"""
Base class for white-box XAI methods.
Expand Down Expand Up @@ -64,15 +64,15 @@ def prepare_model(self, load_model: bool = True) -> ov.Model:
logger.info("Provided IR model already contains XAI branch.")
self._model = self._model_ori
if load_model:
self.load_model()
self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name)
return self._model

xai_output_node = self.generate_xai_branch()
self._model = insert_xai_branch_into_model(self._model_ori, xai_output_node, self.embed_scaling)
if not has_xai(self._model):
raise RuntimeError("Insertion of the XAI branch into the model was not successful.")
if load_model:
self.load_model()
self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name)
return self._model

@staticmethod
Expand Down
Loading

0 comments on commit 8b0ddf9

Please sign in to comment.