Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move insert_xai into separate functional api module #11

Merged
merged 5 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion openvino_xai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
"""


from .api.api import insert_xai
from .common.parameters import Method, Task
from .explainer.explainer import Explainer
from .inserter import insert_xai

__all__ = ["Explainer", "insert_xai", "Method", "Task"]
10 changes: 10 additions & 0 deletions openvino_xai/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""
Finctional API.
"""
from openvino_xai.api.api import insert_xai

__all__ = [
"insert_xai",
]
51 changes: 51 additions & 0 deletions openvino_xai/api/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import openvino.runtime as ov

from openvino_xai.common.parameters import Task
from openvino_xai.common.utils import IdentityPreprocessFN, has_xai, logger
from openvino_xai.inserter.parameters import InsertionParameters
from openvino_xai.methods.factory import WhiteBoxMethodFactory


def insert_xai(
model: ov.Model,
task: Task,
insertion_parameters: InsertionParameters | None = None,
) -> ov.Model:
"""
Function that inserts XAI branch into IR.

Usage:
model_xai = openvino_xai.insert_xai(model, task=Task.CLASSIFICATION)

:param model: Original IR.
:type model: ov.Model | str
:param task: Type of the task: CLASSIFICATION or DETECTION.
:type task: Task
:param insertion_parameters: Insertion parameters that parametrize white-box method,
that will be inserted into the model graph (optional).
:type insertion_parameters: InsertionParameters
:return: IR with XAI branch.
"""

if has_xai(model):
logger.info("Provided IR model already contains XAI branch, return it as-is.")
return model

method = WhiteBoxMethodFactory.create_method(
task=task,
model=model,
preprocess_fn=IdentityPreprocessFN(),
insertion_parameters=insertion_parameters,
prepare_model=False,
)

model_xai = method.prepare_model(load_model=False)

if not has_xai(model_xai):
raise RuntimeError("Insertion of the XAI branch into the model was not successful.")
logger.info("Insertion of the XAI branch into the model was successful.")

return model_xai
2 changes: 0 additions & 2 deletions openvino_xai/inserter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
"""
Interface for inserting XAI branch into OV IR.
"""
from openvino_xai.inserter.inserter import insert_xai
from openvino_xai.inserter.parameters import (
ClassificationInsertionParameters,
DetectionInsertionParameters,
InsertionParameters,
)

__all__ = [
"insert_xai",
"InsertionParameters",
"ClassificationInsertionParameters",
"DetectionInsertionParameters",
Expand Down
52 changes: 1 addition & 51 deletions openvino_xai/inserter/inserter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,7 @@
import openvino.runtime as ov
from openvino.preprocess import PrePostProcessor

from openvino_xai import Task
from openvino_xai.common.utils import (
SALIENCY_MAP_OUTPUT_NAME,
IdentityPreprocessFN,
has_xai,
logger,
)
from openvino_xai.inserter.parameters import InsertionParameters


def insert_xai(
model: ov.Model,
task: Task,
insertion_parameters: InsertionParameters | None = None,
) -> ov.Model:
"""
Function that inserts XAI branch into IR.

Usage:
model_xai = openvino_xai.insert_xai(model, task=Task.CLASSIFICATION)

:param model: Original IR.
:type model: ov.Model | str
:param task: Type of the task: CLASSIFICATION or DETECTION.
:type task: Task
:param insertion_parameters: Insertion parameters that parametrize white-box method,
that will be inserted into the model graph (optional).
:type insertion_parameters: InsertionParameters
:return: IR with XAI branch.
"""
from openvino_xai.methods.factory import WhiteBoxMethodFactory

if has_xai(model):
logger.info("Provided IR model already contains XAI branch, return it as-is.")
return model

method = WhiteBoxMethodFactory.create_method(
task=task,
model=model,
preprocess_fn=IdentityPreprocessFN(),
insertion_parameters=insertion_parameters,
prepare_model=False,
)

model_xai = method.prepare_model(load_model=False)

if not has_xai(model_xai):
raise RuntimeError("Insertion of the XAI branch into the model was not successful.")
logger.info("Insertion of the XAI branch into the model was successful.")

return model_xai
from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME


def insert_xai_branch_into_model(
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import openvino.runtime as ov
import pytest

import openvino_xai as xai
import openvino_xai.api.api as xai
from openvino_xai.common.parameters import Method, Task
from openvino_xai.common.utils import has_xai, retrieve_otx_model
from openvino_xai.explainer.explainer import Explainer
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import openvino.runtime as ov

from openvino_xai.api.api import insert_xai
from openvino_xai.common.parameters import Task
from openvino_xai.common.utils import has_xai, retrieve_otx_model
from openvino_xai.inserter.inserter import insert_xai
from tests.integration.test_classification import DEFAULT_CLS_MODEL

DARA_DIR = Path(".data")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/explanation/test_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import openvino.runtime as ov
import pytest

from openvino_xai.api.api import insert_xai
from openvino_xai.common.parameters import Task
from openvino_xai.common.utils import retrieve_otx_model
from openvino_xai.explainer.explainer import Explainer
Expand All @@ -16,7 +17,6 @@
TargetExplainGroup,
)
from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn
from openvino_xai.inserter.inserter import insert_xai
from openvino_xai.inserter.parameters import ClassificationInsertionParameters
from tests.unit.explanation.test_explanation_utils import VOC_NAMES

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/insertion/test_insertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from openvino import runtime as ov

from openvino_xai import insert_xai
from openvino_xai.api.api import insert_xai
from openvino_xai.common.parameters import Method, Task
from openvino_xai.common.utils import has_xai, retrieve_otx_model
from openvino_xai.inserter.parameters import DetectionInsertionParameters
Expand Down
Loading