Skip to content

Commit

Permalink
Unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
negvet committed Aug 19, 2024
1 parent 208ccb7 commit 37bd496
Showing 1 changed file with 80 additions and 7 deletions.
87 changes: 80 additions & 7 deletions tests/unit/methods/black_box/test_black_box_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

from openvino_xai.common.utils import retrieve_otx_model
from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn
from openvino_xai.methods.black_box.aise import AISEClassification
from openvino_xai.methods.black_box.aise import AISEClassification, AISEDetection
from openvino_xai.methods.black_box.base import Preset
from openvino_xai.methods.black_box.rise import RISE
from tests.intg.test_classification import DEFAULT_CLS_MODEL
from tests.intg.test_detection import DEFAULT_DET_MODEL


class InputSampling:
Expand All @@ -26,23 +27,41 @@ class InputSampling:
)
postprocess_fn = get_postprocess_fn()

def get_model(self, fxt_data_root):
def get_cls_model(self, fxt_data_root):
retrieve_otx_model(fxt_data_root, DEFAULT_CLS_MODEL)
model_path = fxt_data_root / "otx_models" / (DEFAULT_CLS_MODEL + ".xml")
return ov.Core().read_model(model_path)

def get_det_model(self, fxt_data_root):
detection_model = "det_yolox_bccd"
retrieve_otx_model(fxt_data_root, detection_model)
model_path = fxt_data_root / "otx_models" / (detection_model + ".xml")
return ov.Core().read_model(model_path)

def _generate_with_preset(self, method, preset):
_ = method.generate_saliency_map(
data=self.image,
target_indices=[1],
preset=preset,
)

@staticmethod
def preprocess_det_fn(x: np.ndarray) -> np.ndarray:
x = cv2.resize(src=x, dsize=(416, 416)) # OTX YOLOX
x = x.transpose((2, 0, 1))
x = np.expand_dims(x, 0)
return x

@staticmethod
def postprocess_det_fn(x) -> np.ndarray:
"""Returns boxes, scores, labels."""
# return x["boxes"][:, :4], x["boxes"][:, 4], x["labels"]
return x["boxes"][0][:, :4], x["boxes"][0][:, 4], x["labels"][0]

class TestAISE(InputSampling):
class TestAISEClassification(InputSampling):
@pytest.mark.parametrize("target_indices", [[0], [0, 1]])
def test_run(self, target_indices, fxt_data_root: Path):
model = self.get_model(fxt_data_root)
model = self.get_cls_model(fxt_data_root)

aise_method = AISEClassification(model, self.postprocess_fn, self.preprocess_fn)
saliency_map = aise_method.generate_saliency_map(
Expand Down Expand Up @@ -70,7 +89,7 @@ def test_run(self, target_indices, fxt_data_root: Path):
assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)

def test_preset(self, fxt_data_root: Path):
model = self.get_model(fxt_data_root)
model = self.get_cls_model(fxt_data_root)
method = AISEClassification(model, self.postprocess_fn, self.preprocess_fn)

tic = time.time()
Expand All @@ -91,10 +110,64 @@ def test_preset(self, fxt_data_root: Path):
assert time_speed < time_balance < time_quality


class TestAISEDetection(InputSampling):
@pytest.mark.parametrize("target_indices", [[0], [0, 1]])
def test_run(self, target_indices, fxt_data_root: Path):
model = self.get_det_model(fxt_data_root)

aise_method = AISEDetection(model, self.postprocess_det_fn, self.preprocess_det_fn)
saliency_map = aise_method.generate_saliency_map(
data=self.image,
target_indices=target_indices,
preset=Preset.SPEED,
num_iterations_per_kernel=10,
divisors=[5],
)
assert aise_method.num_iterations_per_kernel == 10
assert aise_method.divisors == [5]

assert isinstance(saliency_map, dict)
assert len(saliency_map) == len(target_indices)
for target in target_indices:
assert target in saliency_map

ref_target = 0
assert saliency_map[ref_target].dtype == np.uint8
assert saliency_map[ref_target].shape == (416, 416)
assert (saliency_map[ref_target] >= 0).all() and (saliency_map[ref_target] <= 255).all()

tmp = saliency_map[0]

actual_sal_vals = saliency_map[0][150, 240:250].astype(np.int16)
ref_sal_vals = np.array([152, 168, 184, 199, 213, 225, 235, 243, 247, 249], dtype=np.uint8)
assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)

def test_preset(self, fxt_data_root: Path):
model = self.get_det_model(fxt_data_root)
method = AISEDetection(model, self.postprocess_det_fn, self.preprocess_det_fn)

tic = time.time()
self._generate_with_preset(method, Preset.SPEED)
toc = time.time()
time_speed = toc - tic

tic = time.time()
self._generate_with_preset(method, Preset.BALANCE)
toc = time.time()
time_balance = toc - tic

tic = time.time()
self._generate_with_preset(method, Preset.QUALITY)
toc = time.time()
time_quality = toc - tic

assert time_speed < time_balance < time_quality


class TestRISE(InputSampling):
@pytest.mark.parametrize("target_indices", [[0], None])
def test_run(self, target_indices, fxt_data_root: Path):
model = self.get_model(fxt_data_root)
model = self.get_cls_model(fxt_data_root)

rise_method = RISE(model, self.postprocess_fn, self.preprocess_fn)
saliency_map = rise_method.generate_saliency_map(
Expand Down Expand Up @@ -123,7 +196,7 @@ def test_run(self, target_indices, fxt_data_root: Path):
assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)

def test_preset(self, fxt_data_root: Path):
model = self.get_model(fxt_data_root)
model = self.get_cls_model(fxt_data_root)
method = RISE(model, self.postprocess_fn, self.preprocess_fn)

tic = time.time()
Expand Down

0 comments on commit 37bd496

Please sign in to comment.