From 625c71c80d14f80568c61bfda76fa1d728bfed2c Mon Sep 17 00:00:00 2001 From: Songki Choi Date: Wed, 4 Sep 2024 15:42:41 +0900 Subject: [PATCH 1/5] Support basic sub-string match --- openvino_xai/methods/white_box/torch.py | 15 +++++++++++++-- tests/unit/methods/white_box/test_torch.py | 8 ++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/openvino_xai/methods/white_box/torch.py b/openvino_xai/methods/white_box/torch.py index a1e8e80f..29c72860 100644 --- a/openvino_xai/methods/white_box/torch.py +++ b/openvino_xai/methods/white_box/torch.py @@ -56,8 +56,8 @@ def prepare_model(self, load_model: bool = True) -> torch.nn.Module: model = copy.deepcopy(self._model) # Feature - feature_layer = model.get_submodule(self._target_layer) - feature_layer.register_forward_hook(self._feature_hook) + feature_module = self._find_feature_module(model, self._target_layer) + feature_module.register_forward_hook(self._feature_hook) # Output model.register_forward_hook(self._output_hook) setattr(model, "has_xai", True) @@ -86,6 +86,17 @@ def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping: output[name] = data.numpy(force=True) return output + def _find_feature_module(self, model: torch.nn.Module, target_name: str | None): + if target_name is None: + raise ValueError("Target layer name should be specified") + target_module = None + for name, module in model.named_modules(): + if target_name in name: + target_module = module + if target_module is None: + raise ValueError(f"{target_name} not found in the torch model") + return target_module + def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: self._feature_map = output return output diff --git a/tests/unit/methods/white_box/test_torch.py b/tests/unit/methods/white_box/test_torch.py index fa374f6c..87bf4fae 100644 --- a/tests/unit/methods/white_box/test_torch.py +++ b/tests/unit/methods/white_box/test_torch.py @@ -66,6 +66,14 @@ def forward(self, x: torch.Tensor): def test_torch_method(): model = DummyCNN() + + with pytest.raises(ValueError): + method = TorchWhiteBoxMethod(model=model, target_layer=None) + model_xai = method.prepare_model() + with pytest.raises(ValueError): + method = TorchWhiteBoxMethod(model=model, target_layer="something_else") + model_xai = method.prepare_model() + method = TorchWhiteBoxMethod(model=model, target_layer="feature") model_xai = method.prepare_model() assert has_xai(model_xai) From 8d6ebe38f49e08caade775dce2c099750c379ade Mon Sep 17 00:00:00 2001 From: Songki Choi Date: Fri, 6 Sep 2024 14:04:22 +0900 Subject: [PATCH 2/5] Add basic 4D feature map layer detection --- openvino_xai/methods/white_box/torch.py | 57 +++++++++++++++++----- tests/unit/methods/white_box/test_torch.py | 20 ++++++-- 2 files changed, 62 insertions(+), 15 deletions(-) diff --git a/openvino_xai/methods/white_box/torch.py b/openvino_xai/methods/white_box/torch.py index 29c72860..51b1f79a 100644 --- a/openvino_xai/methods/white_box/torch.py +++ b/openvino_xai/methods/white_box/torch.py @@ -55,11 +55,17 @@ def prepare_model(self, load_model: bool = True) -> torch.nn.Module: return self._model model = copy.deepcopy(self._model) + # Feature - feature_module = self._find_feature_module(model, self._target_layer) - feature_module.register_forward_hook(self._feature_hook) + if self._target_layer: + feature_module = self._find_module_by_name(model, self._target_layer) + feature_module.register_forward_hook(self._feature_hook) + else: + self._detect_hook_handle = 
model.register_forward_pre_hook(self._lazy_detect_hook, prepend=True) + # Output model.register_forward_hook(self._output_hook) + setattr(model, "has_xai", True) model.eval() @@ -86,9 +92,8 @@ def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping: output[name] = data.numpy(force=True) return output - def _find_feature_module(self, model: torch.nn.Module, target_name: str | None): - if target_name is None: - raise ValueError("Target layer name should be specified") + def _find_module_by_name(self, model: torch.nn.Module, target_name: str) -> torch.nn.Module: + """Search layer by name sub string match.""" target_module = None for name, module in model.named_modules(): if target_name in name: @@ -98,10 +103,37 @@ def _find_feature_module(self, model: torch.nn.Module, target_name: str | None): return target_module def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: + """Maipulate feature map for saliency map generation.""" self._feature_map = output return output + def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any: + """Detect feature module in the first foward pass and register feature hook.""" + # Make sure this hook called only 1 time + if detect_hook_handle := getattr(self, "_detect_hook_handle", None): + detect_hook_handle.remove() + delattr(self, "_detect_hook_handle") + + # Find the last layer that outputs 4D tensor during temp forward pass + self._feature_module = None + def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None: + if isinstance(output, torch.Tensor): + shape = output.shape + if len(shape) == 4 and shape[2] > 1 and shape[3] > 1: + self._feature_module = module + global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook) + try: + module.forward(*inputs) + finally: + global_hook_handle.remove() + if self._feature_module is None: + raise RuntimeError(f"Feature module with 4D output not found in the torch model") + + # Set feature hook + self._feature_module.register_forward_hook(self._feature_hook) + def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]: + """Split combined output B0xC into BxC precition and BxCxHxW saliency map.""" return { "prediction": output, SALIENCY_MAP_OUTPUT_NAME: torch.empty_like(output), @@ -164,16 +196,17 @@ def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tens return torch.cat(feature_maps) def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]: - batch_size, _, h, w = self._feature_shape - num_classes = output.shape[1] - predictions = output[:batch_size] - saliency_maps = output[batch_size:] - saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes]) - saliency_maps = saliency_maps.transpose(1, 2) # BxHWxC -> BxCxHW + """Split combined output B0xC into BxC precition and BxCxHxW saliency map.""" + batch_size, _, h, w = self._feature_shape # B0xDxHxW + num_classes = output.shape[1] # C + predictions = output[:batch_size] # BxC + saliency_maps = output[batch_size:] # BHWxC + saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes]) # BxHWxC + saliency_maps = saliency_maps.transpose(1, 2) # BxCxHW if self._embed_scaling: saliency_maps = saliency_maps.reshape((batch_size * num_classes, h * w)) saliency_maps = self._normalize_map(saliency_maps) - saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w]) + saliency_maps = 
saliency_maps.reshape([batch_size, num_classes, h, w]) # BxCxHxW return { "prediction": predictions, SALIENCY_MAP_OUTPUT_NAME: saliency_maps, diff --git a/tests/unit/methods/white_box/test_torch.py b/tests/unit/methods/white_box/test_torch.py index 87bf4fae..9aa9084a 100644 --- a/tests/unit/methods/white_box/test_torch.py +++ b/tests/unit/methods/white_box/test_torch.py @@ -53,13 +53,17 @@ def __init__(self, num_classes: int = 2): self.num_classes = num_classes self.feature = torch.nn.Identity() self.output = torch.nn.LazyLinear(out_features=num_classes) + self.norm = None def forward(self, x: torch.Tensor): b, c, h, w = x.shape + if not self.norm: + self.norm = torch.nn.LayerNorm(c) x = x.reshape(b, c, h * w) x = x.transpose(1, 2) x = torch.cat([torch.rand((b, 1, c)), x], dim=1) x = self.feature(x) + x = x + self.norm(x) x = self.output(x[:, 0]) return torch.nn.functional.softmax(x, dim=1) @@ -67,9 +71,6 @@ def forward(self, x: torch.Tensor): def test_torch_method(): model = DummyCNN() - with pytest.raises(ValueError): - method = TorchWhiteBoxMethod(model=model, target_layer=None) - model_xai = method.prepare_model() with pytest.raises(ValueError): method = TorchWhiteBoxMethod(model=model, target_layer="something_else") model_xai = method.prepare_model() @@ -124,6 +125,19 @@ def test_prepare_model(): assert model_xai == model +def test_lazy_detect_feature_layer(): + model = DummyCNN() + method = TorchWhiteBoxMethod(model=model, target_layer=None) + model_xai = method.prepare_model() + assert hasattr(method, "_detect_hook_handle") + assert has_xai(model_xai) + data = np.random.rand(1, 3, 5, 5) + output = method.model_forward(data) + assert not hasattr(method, "_detect_hook_handle") + assert type(output) == dict + assert method._feature_module is model_xai.feature + + def test_activationmap() -> None: batch_size = 2 num_classes = 3 From 0603eea7312e3da1a549e39a854e2fc0c3819ae6 Mon Sep 17 00:00:00 2001 From: Songki Choi Date: Fri, 6 Sep 2024 16:34:29 +0900 Subject: [PATCH 3/5] Add N-last LayerNorm module detection for ViTs --- openvino_xai/methods/white_box/torch.py | 36 +++++++++++++- tests/unit/methods/white_box/test_torch.py | 56 +++++++++++++++++++--- 2 files changed, 85 insertions(+), 7 deletions(-) diff --git a/openvino_xai/methods/white_box/torch.py b/openvino_xai/methods/white_box/torch.py index 51b1f79a..0bbe8f3c 100644 --- a/openvino_xai/methods/white_box/torch.py +++ b/openvino_xai/methods/white_box/torch.py @@ -116,18 +116,27 @@ def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any: # Find the last layer that outputs 4D tensor during temp forward pass self._feature_module = None + self._num_modules = 0 + def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None: if isinstance(output, torch.Tensor): + module.index = self._num_modules + self._num_modules += 1 shape = output.shape if len(shape) == 4 and shape[2] > 1 and shape[3] > 1: self._feature_module = module + global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook) try: module.forward(*inputs) finally: global_hook_handle.remove() if self._feature_module is None: - raise RuntimeError(f"Feature module with 4D output not found in the torch model") + raise RuntimeError("Feature module with 4D output not found in the torch model") + if self._feature_module.index / self._num_modules < 0.5: # Check if ViT-like architectures + raise RuntimeError( + f"Modules with 4D output end in early-half stages: {100 * self._feature_module.index / self._num_modules}%" + ) # Set 
feature hook self._feature_module.register_forward_hook(self._feature_hook) @@ -257,6 +266,31 @@ def __init__( self._use_gaussian = use_gaussian self._use_cls_token = use_cls_token + def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any: + """Detect feature module in the first foward pass and register feature hook.""" + # Make sure this hook called only 1 time + if detect_hook_handle := getattr(self, "_detect_hook_handle", None): + detect_hook_handle.remove() + delattr(self, "_detect_hook_handle") + + # Find the 3rd last LayerNorm module during temp forward pass + self._feature_modules: list[torch.nn.Module] = [] + + def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None: + if isinstance(module, torch.nn.LayerNorm): + self._feature_modules.append(module) + + global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook) + try: + module.forward(*inputs) + finally: + global_hook_handle.remove() + if len(self._feature_modules) < 3: + raise RuntimeError("Feature modules with LayerNorm is less than 3 in the torch model") + + # Set feature hook + self._feature_modules[-3].register_forward_hook(self._feature_hook) + def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps.""" feature_map = output diff --git a/tests/unit/methods/white_box/test_torch.py b/tests/unit/methods/white_box/test_torch.py index 9aa9084a..e833fee3 100644 --- a/tests/unit/methods/white_box/test_torch.py +++ b/tests/unit/methods/white_box/test_torch.py @@ -35,7 +35,12 @@ class DummyCNN(torch.nn.Module): def __init__(self, num_classes: int = 2): super().__init__() self.num_classes = num_classes - self.feature = torch.nn.Identity() + self.feature = torch.nn.Sequential( + torch.nn.Identity(), + torch.nn.Identity(), + torch.nn.Identity(), + torch.nn.Identity(), + ) self.neck = torch.nn.AdaptiveAvgPool2d((1, 1)) self.output = torch.nn.LazyLinear(out_features=num_classes) @@ -51,19 +56,33 @@ class DummyVIT(torch.nn.Module): def __init__(self, num_classes: int = 2): super().__init__() self.num_classes = num_classes - self.feature = torch.nn.Identity() + self.pre = torch.nn.Sequential( + torch.nn.Identity(), + torch.nn.Identity(), + ) + self.feature = torch.nn.Sequential( + torch.nn.Identity(), + torch.nn.Identity(), + torch.nn.Identity(), + torch.nn.Identity(), + ) self.output = torch.nn.LazyLinear(out_features=num_classes) - self.norm = None + self.norm1 = None def forward(self, x: torch.Tensor): b, c, h, w = x.shape - if not self.norm: - self.norm = torch.nn.LayerNorm(c) + if not self.norm1: + self.norm1 = torch.nn.LayerNorm(c) + self.norm2 = torch.nn.LayerNorm(c) + self.norm3 = torch.nn.LayerNorm(c) + x = self.pre(x) x = x.reshape(b, c, h * w) x = x.transpose(1, 2) x = torch.cat([torch.rand((b, 1, c)), x], dim=1) x = self.feature(x) - x = x + self.norm(x) + x = x + self.norm1(x) + x = x + self.norm2(x) + x = x + self.norm3(x) x = self.output(x[:, 0]) return torch.nn.functional.softmax(x, dim=1) @@ -136,6 +155,31 @@ def test_lazy_detect_feature_layer(): assert not hasattr(method, "_detect_hook_handle") assert type(output) == dict assert method._feature_module is model_xai.feature + output = method.model_forward(data) + assert type(output) == dict # still good for 2nd forward + + model = DummyVIT() + method = TorchWhiteBoxMethod(model=model, target_layer=None) + model_xai = method.prepare_model() + assert hasattr(method, "_detect_hook_handle") + assert 
has_xai(model_xai) + data = np.random.rand(1, 3, 5, 5) + with pytest.raises(RuntimeError): + # 4D feature map search should fail for ViTs + output = method.model_forward(data) + + model = DummyVIT() + method = TorchViTReciproCAM(model=model, target_layer=None) + model_xai = method.prepare_model() + assert hasattr(method, "_detect_hook_handle") + assert has_xai(model_xai) + data = np.random.rand(1, 3, 5, 5) + output = method.model_forward(data) + assert not hasattr(method, "_detect_hook_handle") + assert type(output) == dict + assert method._feature_modules[-3] is model_xai.norm1 + output = method.model_forward(data) + assert type(output) == dict # still good for 2nd forward def test_activationmap() -> None: From 2e97454f7bc6e7b5521e71280f1f61df17846bbd Mon Sep 17 00:00:00 2001 From: Songki Choi Date: Mon, 9 Sep 2024 16:22:48 +0900 Subject: [PATCH 4/5] Add integration test for auto layer detection --- openvino_xai/methods/white_box/torch.py | 70 ++++++++++------------ tests/intg/test_classification_timm.py | 10 +++- tests/unit/methods/test_factory.py | 6 +- tests/unit/methods/white_box/test_torch.py | 35 +++++------ 4 files changed, 56 insertions(+), 65 deletions(-) diff --git a/openvino_xai/methods/white_box/torch.py b/openvino_xai/methods/white_box/torch.py index 0bbe8f3c..801a1715 100644 --- a/openvino_xai/methods/white_box/torch.py +++ b/openvino_xai/methods/white_box/torch.py @@ -41,27 +41,33 @@ def __init__( target_layer: str | None = None, embed_scaling: bool = True, device_name: str = "CPU", + prepare_model: bool = True, **kwargs, ): super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name) self._target_layer = target_layer self._embed_scaling = embed_scaling + if prepare_model: + self.prepare_model() + def prepare_model(self, load_model: bool = True) -> torch.nn.Module: """Return XAI inserted model.""" if has_xai(self._model): if load_model: self._model_compiled = self._model return self._model + if self._model_compiled is not None: + return self._model_compiled model = copy.deepcopy(self._model) # Feature if self._target_layer: feature_module = self._find_module_by_name(model, self._target_layer) - feature_module.register_forward_hook(self._feature_hook) else: - self._detect_hook_handle = model.register_forward_pre_hook(self._lazy_detect_hook, prepend=True) + feature_module = self._find_feature_module(model) + feature_module.register_forward_hook(self._feature_hook) # Output model.register_forward_hook(self._output_hook) @@ -102,18 +108,8 @@ def _find_module_by_name(self, model: torch.nn.Module, target_name: str) -> torc raise ValueError(f"{target_name} not found in the torch model") return target_module - def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: - """Maipulate feature map for saliency map generation.""" - self._feature_map = output - return output - - def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any: - """Detect feature module in the first foward pass and register feature hook.""" - # Make sure this hook called only 1 time - if detect_hook_handle := getattr(self, "_detect_hook_handle", None): - detect_hook_handle.remove() - delattr(self, "_detect_hook_handle") - + def _find_feature_module(self, module: torch.nn.Module) -> torch.nn.Module: + """Detect feature module in the model.""" # Find the last layer that outputs 4D tensor during temp forward pass self._feature_module = None self._num_modules = 0 @@ -128,7 +124,7 @@ def _detect_hook(module: torch.nn.Module, inputs: 
Any, output: Any) -> None: global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook) try: - module.forward(*inputs) + module.forward(torch.zeros((1, 3, 128, 128))) finally: global_hook_handle.remove() if self._feature_module is None: @@ -138,8 +134,12 @@ def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None: f"Modules with 4D output end in early-half stages: {100 * self._feature_module.index / self._num_modules}%" ) - # Set feature hook - self._feature_module.register_forward_hook(self._feature_hook) + return self._feature_module + + def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: + """Maipulate feature map for saliency map generation.""" + self._feature_map = output + return output def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]: """Split combined output B0xC into BxC precition and BxCxHxW saliency map.""" @@ -189,8 +189,8 @@ class TorchReciproCAM(TorchWhiteBoxMethod): """ def __init__(self, *args, optimize_gap: bool = False, **kwargs): - super().__init__(*args, **kwargs) self._optimize_gap = optimize_gap + super().__init__(*args, **kwargs) def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps.""" @@ -262,34 +262,24 @@ def __init__( normalize: bool = True, **kwargs, ) -> None: - super().__init__(*args, **kwargs) self._use_gaussian = use_gaussian self._use_cls_token = use_cls_token + super().__init__(*args, **kwargs) - def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any: - """Detect feature module in the first foward pass and register feature hook.""" - # Make sure this hook called only 1 time - if detect_hook_handle := getattr(self, "_detect_hook_handle", None): - detect_hook_handle.remove() - delattr(self, "_detect_hook_handle") - - # Find the 3rd last LayerNorm module during temp forward pass - self._feature_modules: list[torch.nn.Module] = [] - - def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None: - if isinstance(module, torch.nn.LayerNorm): - self._feature_modules.append(module) + def _find_feature_module(self, module: torch.nn.Module) -> torch.nn.Module: + """Detect feature module in the model.""" + # Find the 3rd last LayerNorm module + self._feature_module = None + feature_modules: list[torch.nn.Module] = [] + for _, submodule in module.named_modules(): + if isinstance(submodule, torch.nn.LayerNorm): + feature_modules.append(submodule) - global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook) - try: - module.forward(*inputs) - finally: - global_hook_handle.remove() - if len(self._feature_modules) < 3: + if len(feature_modules) < 3: raise RuntimeError("Feature modules with LayerNorm is less than 3 in the torch model") - # Set feature hook - self._feature_modules[-3].register_forward_hook(self._feature_hook) + self._feature_module = feature_modules[-3] + return self._feature_module def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps.""" diff --git a/tests/intg/test_classification_timm.py b/tests/intg/test_classification_timm.py index d7667d61..5ce553b6 100644 --- a/tests/intg/test_classification_timm.py +++ b/tests/intg/test_classification_timm.py @@ -454,7 +454,8 @@ def test_model_format(self, model_id, 
explain_mode, model_format): "deit_tiny_patch16_224.fb_in1k", ], ) - def test_torch_insert_xai_with_layer(self, model_id: str): + @pytest.mark.parametrize("detect", ["auto", "name"]) + def test_torch_insert_xai_with_layer(self, model_id: str, detect: str): xai_cfg = { "resnet18.a1_in1k": ("layer4", Method.RECIPROCAM), "efficientnet_b0.ra_in1k": ("bn2", Method.RECIPROCAM), @@ -465,6 +466,9 @@ def test_torch_insert_xai_with_layer(self, model_id: str): model_dir = self.data_dir / "timm_models" / "converted_models" model, model_cfg = self.get_timm_model(model_id, model_dir) + target_layer = xai_cfg[model_id][0] if detect == "name" else None + explain_method = xai_cfg[model_id][1] + image = cv2.imread("tests/assets/cheetah_person.jpg") image = cv2.resize(image, dsize=model_cfg["input_size"][1:]) image = cv2.cvtColor(image, code=cv2.COLOR_BGR2RGB) @@ -478,8 +482,8 @@ def test_torch_insert_xai_with_layer(self, model_id: str): xai_model: torch.nn.Module = insert_xai( model, task=Task.CLASSIFICATION, - target_layer=xai_cfg[model_id][0], - explain_method=xai_cfg[model_id][1], + target_layer=target_layer, + explain_method=explain_method, ) with torch.no_grad(): diff --git a/tests/unit/methods/test_factory.py b/tests/unit/methods/test_factory.py index dac63211..0a0b24d8 100644 --- a/tests/unit/methods/test_factory.py +++ b/tests/unit/methods/test_factory.py @@ -151,7 +151,7 @@ def test_create_wb_det_cnn_method(fxt_data_root: Path): assert str(exc_info.value) == "Requested explanation method abc is not implemented." -def test_create_torch_method(): +def test_create_torch_method(mocker: MockerFixture): model = {} with pytest.raises(ValueError): explain_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn()) @@ -172,6 +172,10 @@ def test_create_torch_method(): Task.DETECTION, model, get_postprocess_fn(), target_layer="" ) + mocker.patch.object(torch_method.TorchActivationMap, "prepare_model") + mocker.patch.object(torch_method.TorchReciproCAM, "prepare_model") + mocker.patch.object(torch_method.TorchViTReciproCAM, "prepare_model") + model = torch.nn.Module() explain_method = WhiteBoxMethodFactory.create_method( Task.CLASSIFICATION, model, get_postprocess_fn(), explain_method=Method.ACTIVATIONMAP diff --git a/tests/unit/methods/white_box/test_torch.py b/tests/unit/methods/white_box/test_torch.py index e833fee3..8bea7d4d 100644 --- a/tests/unit/methods/white_box/test_torch.py +++ b/tests/unit/methods/white_box/test_torch.py @@ -53,9 +53,10 @@ def forward(self, x: torch.Tensor): class DummyVIT(torch.nn.Module): - def __init__(self, num_classes: int = 2): + def __init__(self, num_classes: int = 2, dim: int = 3): super().__init__() self.num_classes = num_classes + self.dim = dim self.pre = torch.nn.Sequential( torch.nn.Identity(), torch.nn.Identity(), @@ -66,15 +67,13 @@ def __init__(self, num_classes: int = 2): torch.nn.Identity(), torch.nn.Identity(), ) + self.norm1 = torch.nn.LayerNorm(dim) + self.norm2 = torch.nn.LayerNorm(dim) + self.norm3 = torch.nn.LayerNorm(dim) self.output = torch.nn.LazyLinear(out_features=num_classes) - self.norm1 = None def forward(self, x: torch.Tensor): b, c, h, w = x.shape - if not self.norm1: - self.norm1 = torch.nn.LayerNorm(c) - self.norm2 = torch.nn.LayerNorm(c) - self.norm3 = torch.nn.LayerNorm(c) x = self.pre(x) x = x.reshape(b, c, h * w) x = x.transpose(1, 2) @@ -129,7 +128,9 @@ def _output_hook( def test_prepare_model(): model = DummyCNN() - method = TorchWhiteBoxMethod(model=model, target_layer="feature") + method = 
TorchWhiteBoxMethod(model=model, target_layer="feature", prepare_model=False) + model_xai = method.prepare_model(load_model=False) + assert method._model_compiled is None model_xai = method.prepare_model(load_model=False) assert method._model_compiled is None assert model is not model_xai @@ -144,40 +145,31 @@ def test_prepare_model(): assert model_xai == model -def test_lazy_detect_feature_layer(): +def test_detect_feature_layer(): model = DummyCNN() method = TorchWhiteBoxMethod(model=model, target_layer=None) model_xai = method.prepare_model() - assert hasattr(method, "_detect_hook_handle") assert has_xai(model_xai) data = np.random.rand(1, 3, 5, 5) output = method.model_forward(data) - assert not hasattr(method, "_detect_hook_handle") assert type(output) == dict assert method._feature_module is model_xai.feature output = method.model_forward(data) assert type(output) == dict # still good for 2nd forward model = DummyVIT() - method = TorchWhiteBoxMethod(model=model, target_layer=None) - model_xai = method.prepare_model() - assert hasattr(method, "_detect_hook_handle") - assert has_xai(model_xai) - data = np.random.rand(1, 3, 5, 5) with pytest.raises(RuntimeError): # 4D feature map search should fail for ViTs - output = method.model_forward(data) + method = TorchWhiteBoxMethod(model=model, target_layer=None) model = DummyVIT() method = TorchViTReciproCAM(model=model, target_layer=None) model_xai = method.prepare_model() - assert hasattr(method, "_detect_hook_handle") assert has_xai(model_xai) data = np.random.rand(1, 3, 5, 5) output = method.model_forward(data) - assert not hasattr(method, "_detect_hook_handle") assert type(output) == dict - assert method._feature_modules[-3] is model_xai.norm1 + assert method._feature_module is model_xai.norm1 output = method.model_forward(data) assert type(output) == dict # still good for 2nd forward @@ -222,13 +214,14 @@ def test_reciprocam(optimize_gap: bool) -> None: def test_vitreciprocam(use_gaussian: bool, use_cls_token: bool) -> None: batch_size = 2 num_classes = 3 - model = DummyVIT(num_classes=num_classes) + dim = 3 + model = DummyVIT(num_classes=num_classes, dim=dim) method = TorchViTReciproCAM( model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token ) model_xai = method.prepare_model() assert has_xai(model_xai) - data = np.random.rand(batch_size, 4, 5, 5) + data = np.random.rand(batch_size, dim, 5, 5) output = method.model_forward(data) assert type(output) == dict saliency_maps = output[SALIENCY_MAP_OUTPUT_NAME] From 6617ccbf73b700f4cd60646ed49d46673e16060d Mon Sep 17 00:00:00 2001 From: Songki Choi Date: Tue, 10 Sep 2024 11:10:35 +0900 Subject: [PATCH 5/5] Apply review comments --- openvino_xai/methods/white_box/torch.py | 32 +++++++++++-------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/openvino_xai/methods/white_box/torch.py b/openvino_xai/methods/white_box/torch.py index 801a1715..2163dca8 100644 --- a/openvino_xai/methods/white_box/torch.py +++ b/openvino_xai/methods/white_box/torch.py @@ -64,9 +64,9 @@ def prepare_model(self, load_model: bool = True) -> torch.nn.Module: # Feature if self._target_layer: - feature_module = self._find_module_by_name(model, self._target_layer) + feature_module = self._find_feature_module_by_name(model, self._target_layer) else: - feature_module = self._find_feature_module(model) + feature_module = self._find_feature_module_auto(model) feature_module.register_forward_hook(self._feature_hook) # Output @@ -98,17 +98,17 @@ def model_forward(self, x: 
np.ndarray, preprocess: bool = True) -> Mapping: output[name] = data.numpy(force=True) return output - def _find_module_by_name(self, model: torch.nn.Module, target_name: str) -> torch.nn.Module: - """Search layer by name sub string match.""" + def _find_feature_module_by_name(self, model: torch.nn.Module, target_name: str) -> torch.nn.Module: + """Search the last layer by name sub string match.""" target_module = None for name, module in model.named_modules(): if target_name in name: target_module = module if target_module is None: - raise ValueError(f"{target_name} not found in the torch model") + raise ValueError(f"{target_name} is not found in the torch model") return target_module - def _find_feature_module(self, module: torch.nn.Module) -> torch.nn.Module: + def _find_feature_module_auto(self, module: torch.nn.Module) -> torch.nn.Module: """Detect feature module in the model.""" # Find the last layer that outputs 4D tensor during temp forward pass self._feature_module = None @@ -128,7 +128,7 @@ def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None: finally: global_hook_handle.remove() if self._feature_module is None: - raise RuntimeError("Feature module with 4D output not found in the torch model") + raise RuntimeError("Feature module with 4D output is not found in the torch model") if self._feature_module.index / self._num_modules < 0.5: # Check if ViT-like architectures raise RuntimeError( f"Modules with 4D output end in early-half stages: {100 * self._feature_module.index / self._num_modules}%" @@ -137,7 +137,7 @@ def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None: return self._feature_module def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: - """Maipulate feature map for saliency map generation.""" + """Manipulate feature map for saliency map generation.""" self._feature_map = output return output @@ -266,19 +266,15 @@ def __init__( self._use_cls_token = use_cls_token super().__init__(*args, **kwargs) - def _find_feature_module(self, module: torch.nn.Module) -> torch.nn.Module: - """Detect feature module in the model.""" - # Find the 3rd last LayerNorm module + def _find_feature_module_auto(self, module: torch.nn.Module) -> torch.nn.Module: + """Detect feature module in the model by finding the 3rd last LayerNorm module.""" self._feature_module = None - feature_modules: list[torch.nn.Module] = [] - for _, submodule in module.named_modules(): - if isinstance(submodule, torch.nn.LayerNorm): - feature_modules.append(submodule) + norm_modules = [m for _, m in module.named_modules() if isinstance(m, torch.nn.LayerNorm)] - if len(feature_modules) < 3: - raise RuntimeError("Feature modules with LayerNorm is less than 3 in the torch model") + if len(norm_modules) < 3: + raise RuntimeError("Feature modules with LayerNorm are less than 3 in the torch model") - self._feature_module = feature_modules[-3] + self._feature_module = norm_modules[-3] return self._feature_module def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
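
For reference, below is a minimal standalone sketch of the auto-detection idea these patches introduce: probe one forward pass and take the last module that emits a 4D BxCxHxW feature map as the feature layer, with the 3rd-last LayerNorm as the ViT-style fallback. The helper names and the toy Sequential model here are illustrative only and are not part of openvino_xai; the real implementation lives in TorchWhiteBoxMethod._find_feature_module_auto and TorchViTReciproCAM._find_feature_module_auto above.

import torch


def find_last_4d_module(model: torch.nn.Module, dummy_input: torch.Tensor) -> torch.nn.Module:
    """Run one probe forward pass and return the last module whose output is a
    4D tensor with spatial size > 1x1 (BxCxHxW), mirroring the CNN path above."""
    found: list[torch.nn.Module] = []

    def _detect(module: torch.nn.Module, inputs, output) -> None:
        if (
            isinstance(output, torch.Tensor)
            and output.dim() == 4
            and output.shape[2] > 1
            and output.shape[3] > 1
        ):
            found.append(module)

    # Global hook: fires once per module during the probe forward pass.
    handle = torch.nn.modules.module.register_module_forward_hook(_detect)
    try:
        model(dummy_input)
    finally:
        # Always remove the global hook, even if the forward pass fails.
        handle.remove()
    if not found:
        raise RuntimeError("No module with a 4D feature-map output was found")
    return found[-1]


def find_vit_feature_module(model: torch.nn.Module) -> torch.nn.Module:
    """ViT fallback: pick the 3rd-last LayerNorm, as TorchViTReciproCAM does."""
    norms = [m for _, m in model.named_modules() if isinstance(m, torch.nn.LayerNorm)]
    if len(norms) < 3:
        raise RuntimeError("Fewer than 3 LayerNorm modules in the torch model")
    return norms[-3]


if __name__ == "__main__":
    # Toy CNN classifier, made up for this example only.
    cnn = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.ReLU(),
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(),
        torch.nn.Linear(8, 2),
    )
    feature_module = find_last_4d_module(cnn, torch.zeros(1, 3, 64, 64))
    print(feature_module)  # the ReLU right before global pooling

Using the global register_module_forward_hook keeps the probe pass free of per-module bookkeeping, and removing the handle in a finally block ensures a failing probe forward does not leave the hook installed, which is the same pattern the patches use.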