From 8b0ddf90ad24958c4d67db4388decddc22b68b65 Mon Sep 17 00:00:00 2001 From: Songki Choi Date: Fri, 6 Sep 2024 08:07:02 +0900 Subject: [PATCH] Support Pytorch models for `insert_xai` API (#61) * Update build dependency with torch support * Apply generic interface for torch model support * Extend has_xai() for torch * Align unit test structure w/ src code * Add default __new__() for MethodBase * Copy torch XAI algos from OTX * Implement inserter for torch models * Implement ActivationMap torch hooks * Implement ReciproCAM torch hooks * Fix shape mismatch w/ optimize_gap=True * Implement ViTReciproCAM torch hooks * Enable torch method creation via factory * Add integration test * Refactor torch method class hierarchy --- .github/workflows/e2e.yml | 4 +- .github/workflows/pre_merge.yml | 2 +- openvino_xai/api/api.py | 17 +- openvino_xai/common/utils.py | 24 +- openvino_xai/inserter/inserter.py | 9 +- openvino_xai/methods/base.py | 33 ++- openvino_xai/methods/black_box/base.py | 4 +- openvino_xai/methods/factory.py | 31 +- .../methods/white_box/activation_map.py | 13 + openvino_xai/methods/white_box/base.py | 6 +- openvino_xai/methods/white_box/recipro_cam.py | 25 ++ openvino_xai/methods/white_box/torch.py | 264 ++++++++++++++++++ pyproject.toml | 3 +- tests/intg/test_accuracy_metrics.py | 2 +- tests/intg/test_classification_timm.py | 57 ++++ tests/unit/{explanation => api}/__init__.py | 0 .../test_insertion.py => api/test_api.py} | 0 tests/unit/common/test_utils.py | 16 ++ .../unit/{insertion => explainer}/__init__.py | 0 .../test_explainer.py | 2 +- .../test_explanation.py | 2 +- .../test_explanation_utils.py | 0 .../test_visualization.py | 0 tests/unit/inserter/__init__.py | 0 .../test_model_parser.py | 0 tests/unit/methods/test_factory.py | 38 +++ tests/unit/methods/white_box/test_torch.py | 172 ++++++++++++ tox.ini | 5 - 28 files changed, 665 insertions(+), 64 deletions(-) create mode 100644 openvino_xai/methods/white_box/torch.py rename tests/unit/{explanation => api}/__init__.py (100%) rename tests/unit/{insertion/test_insertion.py => api/test_api.py} (100%) rename tests/unit/{insertion => explainer}/__init__.py (100%) rename tests/unit/{explanation => explainer}/test_explainer.py (99%) rename tests/unit/{explanation => explainer}/test_explanation.py (98%) rename tests/unit/{explanation => explainer}/test_explanation_utils.py (100%) rename tests/unit/{explanation => explainer}/test_visualization.py (100%) create mode 100644 tests/unit/inserter/__init__.py rename tests/unit/{insertion => inserter}/test_model_parser.py (100%) create mode 100644 tests/unit/methods/white_box/test_torch.py diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 07507e19..aada08cd 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -31,11 +31,11 @@ jobs: - name: Install tox run: python -m pip install tox==4.4.6 - name: Run Functional Test - run: tox -vv -e val-py310 -- -v tests/func --csv=.tox/val-py310/func-test.csv -n 1 --max-worker-restart 100 --clear-cache + run: tox -vv -e dev-py310 -- -v tests/func --csv=.tox/dev-py310/func-test.csv -n 1 --max-worker-restart 100 --clear-cache - name: Upload artifacts uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1 with: name: func-test-results - path: .tox/val-py310/*.csv + path: .tox/dev-py310/*.csv # Use always() to always run this step to publish test results when there are test failures if: ${{ always() }} diff --git a/.github/workflows/pre_merge.yml b/.github/workflows/pre_merge.yml index 
93d03c2c..2e6adb09 100644 --- a/.github/workflows/pre_merge.yml +++ b/.github/workflows/pre_merge.yml @@ -100,7 +100,7 @@ jobs: - name: Install tox run: python -m pip install tox==4.4.6 - name: Run Integration Test - run: tox -vv -e val-py310 -- tests/intg --csv=.tox/dev-py310/intg-test.csv -n 1 --clear-cache + run: tox -vv -e dev-py310 -- tests/intg --csv=.tox/dev-py310/intg-test.csv -n 1 --clear-cache - name: Upload artifacts uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1 with: diff --git a/openvino_xai/api/api.py b/openvino_xai/api/api.py index 5ea8c447..8b28ec4b 100644 --- a/openvino_xai/api/api.py +++ b/openvino_xai/api/api.py @@ -1,31 +1,34 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from typing import List +from typing import List, TypeVar import openvino as ov +import torch from openvino_xai.common.parameters import Method, Task from openvino_xai.common.utils import IdentityPreprocessFN, has_xai, logger from openvino_xai.methods.factory import WhiteBoxMethodFactory +Model = TypeVar("Model", ov.Model, torch.nn.Module) + def insert_xai( - model: ov.Model, + model: Model, task: Task, explain_method: Method | None = None, target_layer: str | List[str] | None = None, embed_scaling: bool | None = True, **kwargs, -) -> ov.Model: +) -> Model: """ - Function that inserts XAI branch into IR. + Inserts XAI branch into the given model. Usage: model_xai = openvino_xai.insert_xai(model, task=Task.CLASSIFICATION) - :param model: Original IR. - :type model: ov.Model | str + :param model: Original model. + :type model: ov.Model | torch.nn.Module :param task: Type of the task: CLASSIFICATION or DETECTION. :type task: Task :parameter explain_method: Explain method to use for model explanation. @@ -37,7 +40,7 @@ def insert_xai( """ if has_xai(model): - logger.info("Provided IR model already contains XAI branch, return it as-is.") + logger.info("Provided model already contains XAI branch, return it as-is.") return model method = WhiteBoxMethodFactory.create_method( diff --git a/openvino_xai/common/utils.py b/openvino_xai/common/utils.py index 4a3e213f..95441048 100644 --- a/openvino_xai/common/utils.py +++ b/openvino_xai/common/utils.py @@ -10,7 +10,8 @@ from urllib.request import urlretrieve import numpy as np -import openvino.runtime as ov +import openvino as ov +import torch logger = logging.getLogger("openvino_xai") handler = logging.StreamHandler() @@ -23,20 +24,23 @@ SALIENCY_MAP_OUTPUT_NAME = "saliency_map" -def has_xai(model: ov.Model) -> bool: +def has_xai(model: ov.Model | torch.nn.Module) -> bool: """ Function checks if the model contains XAI branch. - :param model: OV IR model. - :type model: ov.Model + :param model: Input model to inspect. + :type model: ov.Model | torch.nn.Module :return: True is the model has XAI branch and saliency_map output, False otherwise.
""" - if not isinstance(model, ov.Model): - raise ValueError(f"Input model has to be ov.Model instance, but got{type(model)}.") - for output in model.outputs: - if SALIENCY_MAP_OUTPUT_NAME in output.get_names(): - return True - return False + if isinstance(model, ov.Model): + for output in model.outputs: + if SALIENCY_MAP_OUTPUT_NAME in output.get_names(): + return True + return False + elif isinstance(model, torch.nn.Module): + return getattr(model, "has_xai", False) + else: + raise ValueError(f"Input model has to be openvino.Model or torch.nn.Module instance, but got{type(model)}.") # Not a part of product diff --git a/openvino_xai/inserter/inserter.py b/openvino_xai/inserter/inserter.py index b7f5c1d4..6353662c 100644 --- a/openvino_xai/inserter/inserter.py +++ b/openvino_xai/inserter/inserter.py @@ -1,7 +1,8 @@ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import openvino.runtime as ov + +import openvino as ov from openvino.preprocess import PrePostProcessor from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME @@ -9,10 +10,10 @@ def insert_xai_branch_into_model( model: ov.Model, - xai_output_node, - set_uint8, + xai_output_node: ov.runtime.Node, + set_uint8: bool, ) -> ov.Model: - """Creates new model with XAI branch.""" + """Create new model with XAI branch.""" model_ori_outputs = model.outputs model_ori_params = model.get_parameters() model_xai = ov.Model([*model_ori_outputs, xai_output_node.output(0)], model_ori_params) diff --git a/openvino_xai/methods/base.py b/openvino_xai/methods/base.py index 59033fe9..cb385ac0 100644 --- a/openvino_xai/methods/base.py +++ b/openvino_xai/methods/base.py @@ -3,21 +3,38 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Callable, Dict, List, Mapping, Tuple +from typing import Callable, Dict, Generic, List, Mapping, Tuple, TypeAlias, TypeVar import numpy as np import openvino as ov +import torch from openvino_xai.common.utils import IdentityPreprocessFN +Model = TypeVar("Model", ov.Model, torch.nn.Module) +CompiledModel = TypeVar("CompiledModel", ov.CompiledModel, torch.nn.Module) +PreprocessFn: TypeAlias = Callable[[np.ndarray], np.ndarray] -class MethodBase(ABC): + +class MethodBase(ABC, Generic[Model, CompiledModel]): """Base class for XAI methods.""" + def __new__( + cls, + model: Model | None = None, + *args, + **kwargs, + ): + if isinstance(model, torch.nn.Module): + raise NotImplementedError(f"{type(model)} is not yet supported for {cls}") + elif model is not None and not isinstance(model, ov.Model): + raise ValueError(f"{type(model)} is not supported") + return super().__new__(cls) + def __init__( self, - model: ov.Model = None, - preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(), + model: Model | None = None, + preprocess_fn: PreprocessFn = IdentityPreprocessFN(), device_name: str = "CPU", ): self._model = model @@ -27,11 +44,11 @@ def __init__( self.predictions: Dict[int, Prediction] = {} @property - def model_compiled(self) -> ov.CompiledModel | None: + def model_compiled(self) -> CompiledModel | None: return self._model_compiled @abstractmethod - def prepare_model(self, load_model: bool = True) -> ov.Model: + def prepare_model(self, load_model: bool = True) -> Model: """Model preparation steps.""" def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping: @@ -46,10 +63,6 @@ def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping: def generate_saliency_map(self, data: np.ndarray) -> 
Dict[int, np.ndarray] | np.ndarray: """Saliency map generation.""" - def load_model(self) -> None: - core = ov.Core() - self._model_compiled = core.compile_model(model=self._model, device_name=self._device_name) - @dataclass class Prediction: diff --git a/openvino_xai/methods/black_box/base.py b/openvino_xai/methods/black_box/base.py index 16fe12d3..9a311b09 100644 --- a/openvino_xai/methods/black_box/base.py +++ b/openvino_xai/methods/black_box/base.py @@ -12,7 +12,7 @@ from openvino_xai.methods.black_box.utils import check_classification_output -class BlackBoxXAIMethod(MethodBase): +class BlackBoxXAIMethod(MethodBase[ov.Model, ov.CompiledModel]): """Base class for methods that explain model in Black-Box mode.""" def __init__( @@ -28,7 +28,7 @@ def __init__( def prepare_model(self, load_model: bool = True) -> ov.Model: """Load model prior to inference.""" if load_model: - self.load_model() + self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name) return self._model def get_logits(self, data_preprocessed: np.ndarray) -> np.ndarray: diff --git a/openvino_xai/methods/factory.py b/openvino_xai/methods/factory.py index c3445ece..36229ffd 100644 --- a/openvino_xai/methods/factory.py +++ b/openvino_xai/methods/factory.py @@ -5,7 +5,8 @@ from typing import Callable, List, Mapping import numpy as np -import openvino.runtime as ov +import openvino as ov +import torch from openvino_xai.common.parameters import Method, Task from openvino_xai.common.utils import IdentityPreprocessFN, logger @@ -44,7 +45,7 @@ class WhiteBoxMethodFactory(MethodFactory): def create_method( cls, task: Task, - model: ov.Model, + model: ov.Model | torch.nn.Module, preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(), explain_method: Method | None = None, target_layer: str | List[str] | None = None, @@ -72,11 +73,11 @@ def create_method( device_name, **kwargs, ) - raise ValueError(f"Model type {task} is not supported in white-box mode.") + raise ValueError(f"Task type {task} is not supported in white-box mode.") @staticmethod def create_classification_method( - model: ov.Model, + model: ov.Model | torch.nn.Module, preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(), explain_method: Method | None = None, target_layer: str | None = None, @@ -86,8 +87,8 @@ def create_classification_method( ) -> WhiteBoxMethod: """Generates instance of the classification white-box method class. - :param model: OV IR model. - :type model: ov.Model + :param model: Input model. + :type model: ov.Model | torch.nn.Module :param preprocess_fn: Preprocessing function, identity function by default (assume input images are already preprocessed by user). :type preprocess_fn: Callable[[np.ndarray], np.ndarray] @@ -147,7 +148,7 @@ def create_classification_method( @staticmethod def create_detection_method( - model: ov.Model, + model: ov.Model | torch.nn.Module, preprocess_fn: Callable[[np.ndarray], np.ndarray], explain_method: Method | None = None, target_layer: List[str] | None = None, @@ -157,8 +158,8 @@ def create_detection_method( ) -> WhiteBoxMethod: """Generates instance of the detection white-box method class. - :param model: OV IR model. - :type model: ov.Model + :param model: Input model. + :type model: ov.Model | torch.nn.Module :param preprocess_fn: Preprocessing function, identity function by default (assume input images are already preprocessed by user). 
:type preprocess_fn: Callable[[np.ndarray], np.ndarray] @@ -190,7 +191,7 @@ class BlackBoxMethodFactory(MethodFactory): def create_method( cls, task: Task, - model: ov.Model, + model: ov.Model | torch.nn.Module, postprocess_fn: Callable[[Mapping], np.ndarray], preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(), explain_method: Method | None = None, @@ -209,7 +210,7 @@ def create_method( @staticmethod def create_classification_method( - model: ov.Model, + model: ov.Model | torch.nn.Module, postprocess_fn: Callable[[Mapping], np.ndarray], preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(), explain_method: Method | None = None, @@ -219,9 +220,9 @@ def create_classification_method( """Generates instance of the classification black-box method class. Using AISE as a default method. - :param model: OV IR model. - :type model: ov.Model - :param postprocess_fn: Preprocessing function that extract scores from IR model output. + :param model: Input model. + :type model: ov.Model | torch.nn.Module + :param postprocess_fn: Postprocessing function that extracts scores from the model output. :type postprocess_fn: Callable[[Mapping], np.ndarray] :param preprocess_fn: Preprocessing function, identity function by default (assume input images are already preprocessed by user). :type preprocess_fn: Callable[[np.ndarray], np.ndarray] @@ -237,7 +238,7 @@ def create_classification_method( @staticmethod def create_detection_method( - model: ov.Model, + model: ov.Model | torch.nn.Module, postprocess_fn: Callable[[Mapping], np.ndarray], preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(), explain_method: Method | None = None, diff --git a/openvino_xai/methods/white_box/activation_map.py b/openvino_xai/methods/white_box/activation_map.py index b41413c4..a45fea60 100644 --- a/openvino_xai/methods/white_box/activation_map.py +++ b/openvino_xai/methods/white_box/activation_map.py @@ -5,6 +5,7 @@ import numpy as np import openvino.runtime as ov +import torch from openvino.runtime import opset10 as opset from openvino_xai.common.utils import IdentityPreprocessFN @@ -31,6 +32,18 @@ class ActivationMap(WhiteBoxMethod): :type prepare_model: bool """ + def __new__( + cls, + model: ov.Model | torch.nn.Module | None = None, + *args, + **kwargs, + ): + if isinstance(model, torch.nn.Module): + from .torch import TorchActivationMap + + return TorchActivationMap(model, *args, **kwargs) + return super().__new__(cls) + def __init__( self, model: ov.Model, diff --git a/openvino_xai/methods/white_box/base.py b/openvino_xai/methods/white_box/base.py index 9823e310..e7a8e10e 100644 --- a/openvino_xai/methods/white_box/base.py +++ b/openvino_xai/methods/white_box/base.py @@ -19,7 +19,7 @@ from openvino_xai.methods.base import MethodBase -class WhiteBoxMethod(MethodBase): +class WhiteBoxMethod(MethodBase[ov.Model, ov.CompiledModel]): """ Base class for white-box XAI methods.
@@ -64,7 +64,7 @@ def prepare_model(self, load_model: bool = True) -> ov.Model: logger.info("Provided IR model already contains XAI branch.") self._model = self._model_ori if load_model: - self.load_model() + self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name) return self._model xai_output_node = self.generate_xai_branch() @@ -72,7 +72,7 @@ if not has_xai(self._model): raise RuntimeError("Insertion of the XAI branch into the model was not successful.") if load_model: - self.load_model() + self._model_compiled = ov.Core().compile_model(model=self._model, device_name=self._device_name) return self._model @staticmethod diff --git a/openvino_xai/methods/white_box/recipro_cam.py b/openvino_xai/methods/white_box/recipro_cam.py index da516125..e5ae6f0b 100644 --- a/openvino_xai/methods/white_box/recipro_cam.py +++ b/openvino_xai/methods/white_box/recipro_cam.py @@ -7,6 +7,7 @@ import numpy as np import openvino.runtime as ov +import torch from openvino.runtime import opset10 as opset from openvino_xai.common.utils import IdentityPreprocessFN @@ -78,6 +79,18 @@ class ReciproCAM(FeatureMapPerturbationBase): :type prepare_model: bool """ + def __new__( + cls, + model: ov.Model | torch.nn.Module | None = None, + *args, + **kwargs, + ): + if isinstance(model, torch.nn.Module): + from .torch import TorchReciproCAM + + return TorchReciproCAM(model, *args, **kwargs) + return super().__new__(cls) + def __init__( self, model: ov.Model, @@ -171,6 +184,18 @@ class ViTReciproCAM(FeatureMapPerturbationBase): :type prepare_model: bool """ + def __new__( + cls, + model: ov.Model | torch.nn.Module | None = None, + *args, + **kwargs, + ): + if isinstance(model, torch.nn.Module): + from .torch import TorchViTReciproCAM + + return TorchViTReciproCAM(model, *args, **kwargs) + return super().__new__(cls) + def __init__( self, model: ov.Model, diff --git a/openvino_xai/methods/white_box/torch.py b/openvino_xai/methods/white_box/torch.py new file mode 100644 index 00000000..a1e8e80f --- /dev/null +++ b/openvino_xai/methods/white_box/torch.py @@ -0,0 +1,264 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# Copy & edit from https://github.com/openvinotoolkit/training_extensions/blob/2.1.0/src/otx/algo/explain/explain_algo.py +"""Algorithms for calculating XAI branch for Explainable AI.""" + +import copy +from typing import Any, Callable, Dict, Mapping + +import numpy as np +import torch + +from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME, has_xai +from openvino_xai.methods.base import IdentityPreprocessFN, MethodBase + + +class TorchWhiteBoxMethod(MethodBase[torch.nn.Module, torch.nn.Module]): + """ + Base class for Torch-based methods. + + :param model: Input model. + :type model: torch.nn.Module + :param preprocess_fn: Preprocessing function, identity function by default + (assume input images are already preprocessed by user). + :type preprocess_fn: Callable[[np.ndarray], np.ndarray] + :parameter target_layer: Target layer (node) name after which the XAI branch will be inserted. + :type target_layer: str + :param embed_scaling: Whether to scale output or not. + :type embed_scaling: bool + :param device_name: Device type name.
+ :type device_name: str + """ + + def __new__(cls, *args, **kwargs): + return object.__new__(cls) + + def __init__( + self, + model: torch.nn.Module, + preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(), + target_layer: str | None = None, + embed_scaling: bool = True, + device_name: str = "CPU", + **kwargs, + ): + super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name) + self._target_layer = target_layer + self._embed_scaling = embed_scaling + + def prepare_model(self, load_model: bool = True) -> torch.nn.Module: + """Return XAI inserted model.""" + if has_xai(self._model): + if load_model: + self._model_compiled = self._model + return self._model + + model = copy.deepcopy(self._model) + # Feature + feature_layer = model.get_submodule(self._target_layer) + feature_layer.register_forward_hook(self._feature_hook) + # Output + model.register_forward_hook(self._output_hook) + setattr(model, "has_xai", True) + model.eval() + + if load_model: + self._model_compiled = model + return model + + def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping: + """Process numpy input, return numpy output.""" + if not self._model_compiled: + raise RuntimeError("Model is not compiled. Call prepare_model() first.") + + if preprocess: + x = self.preprocess_fn(x) + x = torch.from_numpy(x).float() + + with torch.no_grad(): + x = self._model_compiled(x) + + output = {} + for name, data in x.items(): + if not isinstance(data, torch.Tensor): + data = torch.tensor(data) + output[name] = data.numpy(force=True) + return output + + def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: + self._feature_map = output + return output + + def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]: + return { + "prediction": output, + SALIENCY_MAP_OUTPUT_NAME: torch.empty_like(output), + } + + def generate_saliency_map(self, data: np.ndarray) -> np.ndarray: + """Return saliency map.""" + model_output = self.model_forward(data) + return model_output[SALIENCY_MAP_OUTPUT_NAME] + + @staticmethod + def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor: + """Normalize saliency maps.""" + max_values = saliency_map.max(dim=-1, keepdim=True).values + min_values = saliency_map.min(dim=-1, keepdim=True).values + saliency_map = 255 * (saliency_map - min_values) / (max_values - min_values + 1e-12) + return saliency_map.to(torch.uint8) + + +class TorchActivationMap(TorchWhiteBoxMethod): + """ActivationMap. Mean of the feature map along the channel dimension.""" + + def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]: + feature_map = self._feature_map + batch_size, _, h, w = feature_map.shape + activation_map = torch.mean(feature_map, dim=1) + if self._embed_scaling: + activation_map = activation_map.reshape((batch_size, h * w)) + activation_map = self._normalize_map(activation_map) + activation_map = activation_map.reshape((batch_size, h, w)) + return { + "prediction": output, + SALIENCY_MAP_OUTPUT_NAME: activation_map, + } + + +class TorchReciproCAM(TorchWhiteBoxMethod): + """Implementation of Recipro-CAM for class-wise saliency map. 
+ + Recipro-CAM: gradient-free reciprocal class activation map (https://arxiv.org/pdf/2209.14074.pdf) + + :param optimize_gap: Whether to optimize out Global Average Pooling operation. + :type optimize_gap: bool + """ + + def __init__(self, *args, optimize_gap: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self._optimize_gap = optimize_gap + + def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: + """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps.""" + batch_size, c, h, w = self._feature_shape = output.shape + feature_map = output + if self._optimize_gap: + feature_map = feature_map.reshape([batch_size, c, h * w]).mean(dim=-1)[:, :, None, None] # Spatial average + feature_maps = [feature_map] + for i in range(batch_size): + mosaic_feature_map = self._get_mosaic_feature_map(output[i], c, h, w) + feature_maps.append(mosaic_feature_map) + return torch.cat(feature_maps) + + def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]: + batch_size, _, h, w = self._feature_shape + num_classes = output.shape[1] + predictions = output[:batch_size] + saliency_maps = output[batch_size:] + saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes]) + saliency_maps = saliency_maps.transpose(1, 2) # BxHWxC -> BxCxHW + if self._embed_scaling: + saliency_maps = saliency_maps.reshape((batch_size * num_classes, h * w)) + saliency_maps = self._normalize_map(saliency_maps) + saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w]) + return { + "prediction": predictions, + SALIENCY_MAP_OUTPUT_NAME: saliency_maps, + } + + def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w: int) -> torch.Tensor: + if self._optimize_gap: + # if isinstance(model_neck, GlobalAveragePooling): + # Optimization workaround for the GAP case (simulate GAP with more simple compute graph) + # Possible due to static sparsity of mosaic_feature_map + # Makes the downstream GAP operation to be dummy + feature_map_transposed = torch.flatten(feature_map, start_dim=1).transpose(0, 1)[:, :, None, None] + mosaic_feature_map = feature_map_transposed / (h * w) + else: + feature_map_repeated = feature_map.repeat(h * w, 1, 1, 1) + mosaic_feature_map_mask = torch.zeros(h * w, c, h, w).to(feature_map.device) + spatial_order = torch.arange(h * w).reshape(h, w) + for i in range(h): + for j in range(w): + k = spatial_order[i, j] + mosaic_feature_map_mask[k, :, i, j] = torch.ones(c).to(feature_map.device) + mosaic_feature_map = feature_map_repeated * mosaic_feature_map_mask + return mosaic_feature_map + + +class TorchViTReciproCAM(TorchReciproCAM): + """Implementation of ViTRecipro-CAM for class-wise saliency map for transformer-based classifiers. + + ViT-ReciproCAM: Gradient and Attention-Free Visual Explanations for Vision Transformer + (https://arxiv.org/abs/2310.02588) + + :param use_gaussian: Defines kernel type for mosaic feature map generation. + If True, use gaussian 3x3 kernel. If False, use 1x1 kernel. + :type use_gaussian: bool + :param use_cls_token: If True, includes classification token into the mosaic feature map.
+ :type use_cls_token: bool + """ + + def __init__( + self, + *args, + use_gaussian: bool = True, + use_cls_token: bool = True, + normalize: bool = True, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self._use_gaussian = use_gaussian + self._use_cls_token = use_cls_token + + def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: + """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps.""" + feature_map = output + batch_size, num_tokens, dim = feature_map.shape + h = w = int((num_tokens - 1) ** 0.5) + feature_maps = [feature_map] + self._feature_shape = (batch_size, dim, h, w) + for i in range(batch_size): + mosaic_feature_map = self._get_mosaic_feature_map(feature_map[i], dim, h, w) + feature_maps.append(mosaic_feature_map) + return torch.cat(feature_maps) + + def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w: int) -> torch.Tensor: + num_tokens = h * w + 1 + mosaic_feature_map = torch.zeros(h * w, num_tokens, c).to(feature_map.device) + + if self._use_gaussian: + if self._use_cls_token: + mosaic_feature_map[:, 0, :] = feature_map[0, :] + feature_map_spatial = feature_map[1:, :].reshape(1, h, w, c) + feature_map_spatial_repeated = feature_map_spatial.repeat(h * w, 1, 1, 1) # 196, 14, 14, 192 + + spatial_order = torch.arange(h * w).reshape(h, w) + gaussian = torch.tensor( + [[1 / 16.0, 1 / 8.0, 1 / 16.0], [1 / 8.0, 1 / 4.0, 1 / 8.0], [1 / 16.0, 1 / 8.0, 1 / 16.0]], + ).to(feature_map.device) + mosaic_feature_map_mask_padded = torch.zeros(h * w, h + 2, w + 2).to(feature_map.device) + for i in range(h): + for j in range(w): + k = spatial_order[i, j] + i_pad = i + 1 + j_pad = j + 1 + mosaic_feature_map_mask_padded[k, i_pad - 1 : i_pad + 2, j_pad - 1 : j_pad + 2] = gaussian + mosaic_feature_map_mask = mosaic_feature_map_mask_padded[:, 1:-1, 1:-1] + mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(3).repeat(1, 1, 1, c) + + mosaic_fm_wo_cls_token = feature_map_spatial_repeated * mosaic_feature_map_mask + mosaic_feature_map[:, 1:, :] = mosaic_fm_wo_cls_token.reshape(h * w, h * w, c) + else: + feature_map_repeated = feature_map.unsqueeze(0).repeat(h * w, 1, 1) + mosaic_feature_map_mask = torch.zeros(h * w, num_tokens).to(feature_map.device) + for i in range(h * w): + mosaic_feature_map_mask[i, i + 1] = torch.ones(1).to(feature_map.device) + if self._use_cls_token: + mosaic_feature_map_mask[:, 0] = torch.ones(1).to(feature_map.device) + mosaic_feature_map_mask = mosaic_feature_map_mask.unsqueeze(2).repeat(1, 1, c) + mosaic_feature_map = feature_map_repeated * mosaic_feature_map_mask + + return mosaic_feature_map diff --git a/pyproject.toml b/pyproject.toml index 8affae50..edb75e20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "numpy==1.*", "tqdm", "matplotlib", + "torch", ] requires-python = ">=3.10" authors = [ @@ -42,8 +43,6 @@ dev = [ "pytest-xdist", "pre-commit==3.7.0", "addict", -] -val = [ "timm==0.9.5", "onnx==1.14.1", "pandas", diff --git a/tests/intg/test_accuracy_metrics.py b/tests/intg/test_accuracy_metrics.py index e75f92bd..8a98c993 100644 --- a/tests/intg/test_accuracy_metrics.py +++ b/tests/intg/test_accuracy_metrics.py @@ -18,7 +18,7 @@ get_preprocess_fn, ) from openvino_xai.metrics import ADCC, InsertionDeletionAUC, PointingGame -from tests.unit.explanation.test_explanation_utils import VOC_NAMES +from tests.unit.explainer.test_explanation_utils import VOC_NAMES MODEL_NAME = "mlc_mobilenetv3_large_voc" IMAGE_PATH = 
"tests/assets/cheetah_person.jpg" diff --git a/tests/intg/test_classification_timm.py b/tests/intg/test_classification_timm.py index d2fb3d47..d7667d61 100644 --- a/tests/intg/test_classification_timm.py +++ b/tests/intg/test_classification_timm.py @@ -11,6 +11,7 @@ import openvino as ov import pytest +from openvino_xai import insert_xai from openvino_xai.common.parameters import Method, Task from openvino_xai.explainer.explainer import Explainer, ExplainMode from openvino_xai.explainer.utils import ( @@ -363,6 +364,7 @@ def test_ovc_model_white_box(self, model_id): assert explanation is not None assert explanation.shape[-1] > 1 and explanation.shape[-2] > 1 print(f"{model_id}: Generated classification saliency maps with shape {explanation.shape}.") + self.clear_cache() @pytest.mark.parametrize( "model_id", @@ -441,6 +443,61 @@ def test_model_format(self, model_id, explain_mode, model_format): assert explanation is not None assert explanation.shape[-1] > 1 and explanation.shape[-2] > 1 print(f"{model_id}: Generated classification saliency maps with shape {explanation.shape}.") + self.clear_cache() + + @pytest.mark.parametrize( + "model_id", + [ + "resnet18.a1_in1k", + "efficientnet_b0.ra_in1k", + "vit_tiny_patch16_224.augreg_in21k", + "deit_tiny_patch16_224.fb_in1k", + ], + ) + def test_torch_insert_xai_with_layer(self, model_id: str): + xai_cfg = { + "resnet18.a1_in1k": ("layer4", Method.RECIPROCAM), + "efficientnet_b0.ra_in1k": ("bn2", Method.RECIPROCAM), + "vit_tiny_patch16_224.augreg_in21k": ("blocks.9.norm1", Method.VITRECIPROCAM), + "deit_tiny_patch16_224.fb_in1k": ("blocks.9.norm1", Method.VITRECIPROCAM), + } + + model_dir = self.data_dir / "timm_models" / "converted_models" + model, model_cfg = self.get_timm_model(model_id, model_dir) + + image = cv2.imread("tests/assets/cheetah_person.jpg") + image = cv2.resize(image, dsize=model_cfg["input_size"][1:]) + image = cv2.cvtColor(image, code=cv2.COLOR_BGR2RGB) + mean = np.array(model.default_cfg["mean"]) + std = np.array(model.default_cfg["std"]) + image_norm = (image / 255.0 - mean) / std + image_norm = image_norm.transpose((2, 0, 1)) # HWC -> CHW + image_norm = image_norm[None, :] # CHW -> 1CHW + target_class = self.supported_num_classes[model_cfg["num_classes"]] + + xai_model: torch.nn.Module = insert_xai( + model, + task=Task.CLASSIFICATION, + target_layer=xai_cfg[model_id][0], + explain_method=xai_cfg[model_id][1], + ) + + with torch.no_grad(): + xai_model.eval() + xai_output = xai_model(torch.from_numpy(image_norm).float()) + xai_logit = xai_output["prediction"] + xai_prob = torch.softmax(xai_logit, dim=-1) + xai_label = xai_prob.argmax(dim=-1)[0] + assert xai_label.item() == target_class + assert xai_prob[0, xai_label].item() > 0.0 + + saliency_map: np.ndarray = xai_output["saliency_map"].numpy(force=True) + saliency_map = saliency_map.squeeze(0) + assert saliency_map.shape[-1] > 1 and saliency_map.shape[-2] > 1 + assert saliency_map.min() < saliency_map.max() + assert saliency_map.dtype == np.uint8 + + self.clear_cache() def check_for_saved_map(self, model_id, directory): for target in self.supported_num_classes.values(): diff --git a/tests/unit/explanation/__init__.py b/tests/unit/api/__init__.py similarity index 100% rename from tests/unit/explanation/__init__.py rename to tests/unit/api/__init__.py diff --git a/tests/unit/insertion/test_insertion.py b/tests/unit/api/test_api.py similarity index 100% rename from tests/unit/insertion/test_insertion.py rename to tests/unit/api/test_api.py diff --git 
a/tests/unit/common/test_utils.py b/tests/unit/common/test_utils.py index f6aa3468..20f5a0df 100644 --- a/tests/unit/common/test_utils.py +++ b/tests/unit/common/test_utils.py @@ -4,6 +4,9 @@ from pathlib import Path import openvino as ov +import pytest +import torch +from pytest_mock import MockerFixture from openvino_xai.api.api import insert_xai from openvino_xai.common.parameters import Task @@ -12,6 +15,7 @@ def test_has_xai(fxt_data_root: Path): + # OV model_without_xai = DEFAULT_CLS_MODEL retrieve_otx_model(fxt_data_root, model_without_xai) model_path = fxt_data_root / "otx_models" / (model_without_xai + ".xml") @@ -25,3 +29,15 @@ def test_has_xai(fxt_data_root: Path): ) assert has_xai(model_xai) + + # Torch + model = torch.nn.Module() + assert has_xai(model) == False + model.has_xai = True + assert has_xai(model) == True + + # Other + with pytest.raises(ValueError): + has_xai(None) + with pytest.raises(ValueError): + has_xai(object) diff --git a/tests/unit/insertion/__init__.py b/tests/unit/explainer/__init__.py similarity index 100% rename from tests/unit/insertion/__init__.py rename to tests/unit/explainer/__init__.py diff --git a/tests/unit/explanation/test_explainer.py b/tests/unit/explainer/test_explainer.py similarity index 99% rename from tests/unit/explanation/test_explainer.py rename to tests/unit/explainer/test_explainer.py index bdeceea6..c0cad11e 100644 --- a/tests/unit/explanation/test_explainer.py +++ b/tests/unit/explainer/test_explainer.py @@ -15,7 +15,7 @@ from openvino_xai.explainer.explainer import Explainer, ExplainMode from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn from openvino_xai.methods.black_box.base import Preset -from tests.unit.explanation.test_explanation_utils import VOC_NAMES +from tests.unit.explainer.test_explanation_utils import VOC_NAMES MODEL_NAME = "mlc_mobilenetv3_large_voc" diff --git a/tests/unit/explanation/test_explanation.py b/tests/unit/explainer/test_explanation.py similarity index 98% rename from tests/unit/explanation/test_explanation.py rename to tests/unit/explainer/test_explanation.py index c72e233d..fed82447 100644 --- a/tests/unit/explanation/test_explanation.py +++ b/tests/unit/explainer/test_explanation.py @@ -9,7 +9,7 @@ from openvino_xai.common.parameters import Task from openvino_xai.explainer.explanation import Explanation -from tests.unit.explanation.test_explanation_utils import VOC_NAMES +from tests.unit.explainer.test_explanation_utils import VOC_NAMES SALIENCY_MAPS = (np.random.rand(1, 20, 5, 5) * 255).astype(np.uint8) SALIENCY_MAPS_IMAGE = (np.random.rand(1, 5, 5) * 255).astype(np.uint8) diff --git a/tests/unit/explanation/test_explanation_utils.py b/tests/unit/explainer/test_explanation_utils.py similarity index 100% rename from tests/unit/explanation/test_explanation_utils.py rename to tests/unit/explainer/test_explanation_utils.py diff --git a/tests/unit/explanation/test_visualization.py b/tests/unit/explainer/test_visualization.py similarity index 100% rename from tests/unit/explanation/test_visualization.py rename to tests/unit/explainer/test_visualization.py diff --git a/tests/unit/inserter/__init__.py b/tests/unit/inserter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/insertion/test_model_parser.py b/tests/unit/inserter/test_model_parser.py similarity index 100% rename from tests/unit/insertion/test_model_parser.py rename to tests/unit/inserter/test_model_parser.py diff --git a/tests/unit/methods/test_factory.py 
b/tests/unit/methods/test_factory.py index b197c1c2..dac63211 100644 --- a/tests/unit/methods/test_factory.py +++ b/tests/unit/methods/test_factory.py @@ -5,6 +5,7 @@ import openvino as ov import pytest +import torch from pytest_mock import MockerFixture from openvino_xai.common.parameters import Method, Task @@ -12,6 +13,7 @@ from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn from openvino_xai.methods.black_box.aise.classification import AISEClassification from openvino_xai.methods.factory import BlackBoxMethodFactory, WhiteBoxMethodFactory +from openvino_xai.methods.white_box import torch as torch_method from openvino_xai.methods.white_box.activation_map import ActivationMap from openvino_xai.methods.white_box.det_class_probability_map import ( DetClassProbabilityMap, @@ -147,3 +149,39 @@ def test_create_wb_det_cnn_method(fxt_data_root: Path): saliency_map_size=sal_map_size, ) assert str(exc_info.value) == "Requested explanation method abc is not implemented." + + +def test_create_torch_method(): + model = {} + with pytest.raises(ValueError): + explain_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn()) + model = torch.nn.Module() + with pytest.raises(NotImplementedError): + explain_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn()) + with pytest.raises(NotImplementedError): + explain_method = BlackBoxMethodFactory.create_method( + Task.DETECTION, model, get_postprocess_fn(), target_layer="" + ) + + model = {} + with pytest.raises(ValueError): + explain_method = WhiteBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn()) + model = torch.nn.Module() + with pytest.raises(NotImplementedError): + explain_method = WhiteBoxMethodFactory.create_method( + Task.DETECTION, model, get_postprocess_fn(), target_layer="" + ) + + model = torch.nn.Module() + explain_method = WhiteBoxMethodFactory.create_method( + Task.CLASSIFICATION, model, get_postprocess_fn(), explain_method=Method.ACTIVATIONMAP + ) + assert isinstance(explain_method, torch_method.TorchActivationMap) + explain_method = WhiteBoxMethodFactory.create_method( + Task.CLASSIFICATION, model, get_postprocess_fn(), explain_method=Method.RECIPROCAM + ) + assert isinstance(explain_method, torch_method.TorchReciproCAM) + explain_method = WhiteBoxMethodFactory.create_method( + Task.CLASSIFICATION, model, get_postprocess_fn(), explain_method=Method.VITRECIPROCAM + ) + assert isinstance(explain_method, torch_method.TorchViTReciproCAM) diff --git a/tests/unit/methods/white_box/test_torch.py b/tests/unit/methods/white_box/test_torch.py new file mode 100644 index 00000000..fa374f6c --- /dev/null +++ b/tests/unit/methods/white_box/test_torch.py @@ -0,0 +1,172 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +# Copy & edit from https://github.com/openvinotoolkit/training_extensions/blob/2.1.0/tests/unit/algo/explain/test_xai_algorithms.py + +from typing import Any, Callable, Dict, Mapping, Sequence, TypeAlias + +import numpy as np +import pytest +import torch + +from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME, has_xai +from openvino_xai.methods.white_box.torch import ( + TorchActivationMap, + TorchReciproCAM, + TorchViTReciproCAM, + TorchWhiteBoxMethod, +) + + +def test_normalize(): + x = torch.rand((2, 2)) * 100 + y = TorchWhiteBoxMethod._normalize_map(x) + assert x.shape == y.shape + assert torch.all(y >= 0) + assert torch.all(y <= 255) + x = torch.rand((2, 2, 2)) 
* 100 + y = TorchWhiteBoxMethod._normalize_map(x) + assert x.shape == y.shape + assert torch.all(y >= 0) + assert torch.all(y <= 255) + + +class DummyCNN(torch.nn.Module): + def __init__(self, num_classes: int = 2): + super().__init__() + self.num_classes = num_classes + self.feature = torch.nn.Identity() + self.neck = torch.nn.AdaptiveAvgPool2d((1, 1)) + self.output = torch.nn.LazyLinear(out_features=num_classes) + + def forward(self, x: torch.Tensor): + x = self.feature(x) + x = self.neck(x) + x = x.view(x.shape[0], -1) + x = self.output(x) + return torch.nn.functional.softmax(x, dim=1) + + +class DummyVIT(torch.nn.Module): + def __init__(self, num_classes: int = 2): + super().__init__() + self.num_classes = num_classes + self.feature = torch.nn.Identity() + self.output = torch.nn.LazyLinear(out_features=num_classes) + + def forward(self, x: torch.Tensor): + b, c, h, w = x.shape + x = x.reshape(b, c, h * w) + x = x.transpose(1, 2) + x = torch.cat([torch.rand((b, 1, c)), x], dim=1) + x = self.feature(x) + x = self.output(x[:, 0]) + return torch.nn.functional.softmax(x, dim=1) + + +def test_torch_method(): + model = DummyCNN() + method = TorchWhiteBoxMethod(model=model, target_layer="feature") + model_xai = method.prepare_model() + assert has_xai(model_xai) + data = np.zeros((1, 3, 5, 5)) + output = method.model_forward(data) + assert type(output) == dict + assert SALIENCY_MAP_OUTPUT_NAME in output + + class DummyMethod(TorchWhiteBoxMethod): + def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: + output = torch.cat((output, output), dim=0) + return super()._feature_hook(module, inputs, output) + + def _output_hook( + self, module: torch.nn.Module, inputs: Any, output: torch.Tensor + ) -> Dict[str, torch.Tensor | None]: + return { + "prediction": output[0:], + SALIENCY_MAP_OUTPUT_NAME: output[1:], + } + + model = DummyCNN() + method = DummyMethod(model=model, target_layer="feature") + model_xai = method.prepare_model() + assert has_xai(model_xai) + data = np.random.rand(1, 3, 5, 5) + output = method.model_forward(data) + assert type(output) == dict + prediction = output["prediction"] + saliency_maps = output[SALIENCY_MAP_OUTPUT_NAME] + assert np.all(saliency_maps == prediction) + + +def test_prepare_model(): + model = DummyCNN() + method = TorchWhiteBoxMethod(model=model, target_layer="feature") + model_xai = method.prepare_model(load_model=False) + assert method._model_compiled is None + assert model is not model_xai + + model_xai = method.prepare_model(load_model=True) + assert method._model_compiled is not None + assert model is not model_xai + + model.has_xai = True + method = TorchWhiteBoxMethod(model=model, target_layer="feature") + model_xai = method.prepare_model(load_model=False) + assert model_xai == model + + +def test_activationmap() -> None: + batch_size = 2 + num_classes = 3 + model = DummyCNN(num_classes=num_classes) + method = TorchActivationMap(model=model, target_layer="feature") + model_xai = method.prepare_model() + assert has_xai(model_xai) + data = np.random.rand(batch_size, 3, 5, 5) + output = method.model_forward(data) + assert type(output) == dict + saliency_maps = output[SALIENCY_MAP_OUTPUT_NAME] + assert saliency_maps.shape == torch.Size([batch_size, 5, 5]) + assert np.all(saliency_maps >= 0) + assert np.all(saliency_maps <= 255) + assert saliency_maps.dtype == np.uint8 + + +@pytest.mark.parametrize("optimize_gap", [True, False]) +def test_reciprocam(optimize_gap: bool) -> None: + batch_size = 2 + num_classes = 3 + 
model = DummyCNN(num_classes=num_classes) + method = TorchReciproCAM(model=model, target_layer="feature", optimize_gap=optimize_gap) + model_xai = method.prepare_model() + assert has_xai(model_xai) + data = np.random.rand(batch_size, 4, 5, 5) + output = method.model_forward(data) + assert type(output) == dict + saliency_maps = output[SALIENCY_MAP_OUTPUT_NAME] + assert saliency_maps.shape == torch.Size([batch_size, num_classes, 5, 5]) + assert np.all(saliency_maps >= 0) + assert np.all(saliency_maps <= 255) + assert saliency_maps.dtype == np.uint8 + + +@pytest.mark.parametrize("use_gaussian", [True, False]) +@pytest.mark.parametrize("use_cls_token", [True, False]) +def test_vitreciprocam(use_gaussian: bool, use_cls_token: bool) -> None: + batch_size = 2 + num_classes = 3 + model = DummyVIT(num_classes=num_classes) + method = TorchViTReciproCAM( + model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token + ) + model_xai = method.prepare_model() + assert has_xai(model_xai) + data = np.random.rand(batch_size, 4, 5, 5) + output = method.model_forward(data) + assert type(output) == dict + saliency_maps = output[SALIENCY_MAP_OUTPUT_NAME] + assert saliency_maps.shape == torch.Size([batch_size, num_classes, 5, 5]) + assert np.all(saliency_maps >= 0) + assert np.all(saliency_maps <= 255) + assert saliency_maps.dtype == np.uint8 diff --git a/tox.ini b/tox.ini index b656e6bd..218ec2c0 100644 --- a/tox.ini +++ b/tox.ini @@ -21,11 +21,6 @@ extras = dev commands = pytest -ra --showlocals {posargs:tests/} -[testenv:val-{py310, py311}] -extras = dev,val -commands = - pytest -ra --showlocals {posargs:tests/} - [testenv:fuzz-{py310, py311}] deps = atheris
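
Usage reference: a minimal sketch of the PyTorch path added by this patch, modeled on the new `test_torch_insert_xai_with_layer` integration test. The timm checkpoint, `pretrained=True`, and the random input below are illustrative assumptions; the `insert_xai` call, `target_layer`/`explain_method` values, and the `"prediction"`/`"saliency_map"` outputs come from the code and tests in this change.

```python
# Sketch: insert an XAI branch into a torch classifier and read back class-wise saliency maps.
# Assumptions: timm is installed; "resnet18.a1_in1k" with target_layer="layer4" mirrors the test config.
import numpy as np
import timm
import torch

from openvino_xai import insert_xai
from openvino_xai.common.parameters import Method, Task

model = timm.create_model("resnet18.a1_in1k", pretrained=True)

# insert_xai() now also accepts torch.nn.Module and returns a torch.nn.Module whose forward()
# yields a dict with "prediction" logits and a "saliency_map" tensor (uint8 when embed_scaling=True).
model_xai: torch.nn.Module = insert_xai(
    model,
    task=Task.CLASSIFICATION,
    target_layer="layer4",             # feature layer after which the hook is registered
    explain_method=Method.RECIPROCAM,
)

model_xai.eval()
image = np.random.rand(1, 3, 224, 224).astype(np.float32)  # stand-in for an already-preprocessed NCHW batch
with torch.no_grad():
    output = model_xai(torch.from_numpy(image))

logits = output["prediction"]                              # shape (1, num_classes)
saliency_map = output["saliency_map"].numpy(force=True)    # uint8, shape (1, num_classes, h, w)
```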