From 46aaf6c57dd829fe93aa260b5a2999f963829ad7 Mon Sep 17 00:00:00 2001 From: GalyaZalesskaya Date: Fri, 27 Sep 2024 15:19:25 +0300 Subject: [PATCH] Fixes from comments --- openvino_xai/metrics/adcc.py | 2 +- openvino_xai/metrics/base.py | 2 +- tests/perf/conftest.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/openvino_xai/metrics/adcc.py b/openvino_xai/metrics/adcc.py index 7c4933c0..e881f793 100644 --- a/openvino_xai/metrics/adcc.py +++ b/openvino_xai/metrics/adcc.py @@ -20,7 +20,7 @@ class ADCC(BaseMetric): https://github.com/aimagelab/ADCC/ """ - def __init__(self, model, preprocess_fn, postprocess_fn, explainer, device_name="AUTO", **kwargs: Any): + def __init__(self, model, preprocess_fn, postprocess_fn, explainer, device_name="CPU", **kwargs: Any): super().__init__( model=model, preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn, device_name=device_name ) diff --git a/openvino_xai/metrics/base.py b/openvino_xai/metrics/base.py index 1b53bb97..90b2fa8b 100644 --- a/openvino_xai/metrics/base.py +++ b/openvino_xai/metrics/base.py @@ -16,7 +16,7 @@ def __init__( model: ov.Model = None, preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(), postprocess_fn: Callable[[np.ndarray], np.ndarray] = None, - device_name: str = "AUTO", + device_name: str = "CPU", ): # Pass model_predict to class initialization directly? self.model = model diff --git a/tests/perf/conftest.py b/tests/perf/conftest.py index b90d3cfe..2e387bc7 100644 --- a/tests/perf/conftest.py +++ b/tests/perf/conftest.py @@ -33,7 +33,7 @@ def pytest_addoption(parser: pytest.Parser): help="Number of masks for black box methods." "Defaults to 5000.", ) parser.addoption( - "--dataset-data-root", + "--dataset-root", action="store", default="", help="Path to directory with dataset images.", @@ -190,7 +190,7 @@ def fxt_perf_summary( @pytest.fixture(scope="session") def fxt_dataset_parameters(request: pytest.FixtureRequest) -> tuple[Path | None, Path | None]: """Retrieve dataset parameters for tests.""" - data_root = request.config.getoption("--dataset-data-root") + data_root = request.config.getoption("--dataset-root") ann_path = request.config.getoption("--dataset-ann-path") if data_root != "":