Skip to content

Commit

Permalink
fix run_detection
Browse files Browse the repository at this point in the history
  • Loading branch information
negvet committed Aug 20, 2024
1 parent 0099855 commit 23714d4
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions examples/run_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import openvino_xai as xai
from openvino_xai.common.utils import logger
from openvino_xai.explainer.explainer import ExplainMode
from openvino_xai.methods.black_box.base import Preset


def get_argument_parser():
Expand All @@ -24,8 +25,8 @@ def get_argument_parser():

def preprocess_fn(x: np.ndarray) -> np.ndarray:
# TODO: make sure it is correct
x = cv2.resize(src=x, dsize=(416, 416)) # OTX YOLOX
# x = cv2.resize(src=x, dsize=(992, 736)) # OTX ATSS
# x = cv2.resize(src=x, dsize=(416, 416)) # OTX YOLOX
x = cv2.resize(src=x, dsize=(992, 736)) # OTX ATSS
x = x.transpose((2, 0, 1))
x = np.expand_dims(x, 0)
return x
Expand All @@ -46,19 +47,19 @@ def explain_white_box(args):
model: ov.Model
model = ov.Core().read_model(args.model_path)

# OTX YOLOX
cls_head_output_node_names = [
"/bbox_head/multi_level_conv_cls.0/Conv/WithoutBiases",
"/bbox_head/multi_level_conv_cls.1/Conv/WithoutBiases",
"/bbox_head/multi_level_conv_cls.2/Conv/WithoutBiases",
]
# # OTX ATSS
# # OTX YOLOX
# cls_head_output_node_names = [
# "/bbox_head/atss_cls_1/Conv/WithoutBiases",
# "/bbox_head/atss_cls_2/Conv/WithoutBiases",
# "/bbox_head/atss_cls_3/Conv/WithoutBiases",
# "/bbox_head/atss_cls_4/Conv/WithoutBiases",
# "/bbox_head/multi_level_conv_cls.0/Conv/WithoutBiases",
# "/bbox_head/multi_level_conv_cls.1/Conv/WithoutBiases",
# "/bbox_head/multi_level_conv_cls.2/Conv/WithoutBiases",
# ]
# OTX ATSS
cls_head_output_node_names = [
"/bbox_head/atss_cls_1/Conv/WithoutBiases",
"/bbox_head/atss_cls_2/Conv/WithoutBiases",
"/bbox_head/atss_cls_3/Conv/WithoutBiases",
"/bbox_head/atss_cls_4/Conv/WithoutBiases",
]

# Create explainer object
explainer = xai.Explainer(
Expand Down Expand Up @@ -117,6 +118,7 @@ def explain_black_box(args):
image,
targets=[0], # target boxes to explain
overlay=True,
preset=Preset.SPEED,
)

logger.info(
Expand Down

0 comments on commit 23714d4

Please sign in to comment.