diff --git a/openvino_xai/methods/white_box/torch.py b/openvino_xai/methods/white_box/torch.py
index 51b1f79a..0bbe8f3c 100644
--- a/openvino_xai/methods/white_box/torch.py
+++ b/openvino_xai/methods/white_box/torch.py
@@ -116,18 +116,27 @@
     def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any:
         # Find the last layer that outputs 4D tensor during temp forward pass
         self._feature_module = None
+        self._num_modules = 0
+
         def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
             if isinstance(output, torch.Tensor):
+                module.index = self._num_modules
+                self._num_modules += 1
                 shape = output.shape
                 if len(shape) == 4 and shape[2] > 1 and shape[3] > 1:
                     self._feature_module = module
+
         global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
         try:
             module.forward(*inputs)
         finally:
             global_hook_handle.remove()
         if self._feature_module is None:
-            raise RuntimeError(f"Feature module with 4D output not found in the torch model")
+            raise RuntimeError("Feature module with 4D output not found in the torch model")
+        if self._feature_module.index / self._num_modules < 0.5:  # Check for ViT-like architecture
+            raise RuntimeError(
+                f"Modules with 4D output end in the early half of the model: {100 * self._feature_module.index / self._num_modules:.1f}%"
+            )

         # Set feature hook
         self._feature_module.register_forward_hook(self._feature_hook)
@@ -257,6 +266,31 @@ def __init__(
         self._use_gaussian = use_gaussian
         self._use_cls_token = use_cls_token

+    def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any:
+        """Detect feature module in the first forward pass and register feature hook."""
+        # Make sure this hook is called only once
+        if detect_hook_handle := getattr(self, "_detect_hook_handle", None):
+            detect_hook_handle.remove()
+            delattr(self, "_detect_hook_handle")
+
+        # Find the 3rd-last LayerNorm module during temp forward pass
+        self._feature_modules: list[torch.nn.Module] = []
+
+        def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
+            if isinstance(module, torch.nn.LayerNorm):
+                self._feature_modules.append(module)
+
+        global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
+        try:
+            module.forward(*inputs)
+        finally:
+            global_hook_handle.remove()
+        if len(self._feature_modules) < 3:
+            raise RuntimeError("Fewer than 3 LayerNorm modules found in the torch model")
+
+        # Set feature hook
+        self._feature_modules[-3].register_forward_hook(self._feature_hook)
+
     def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
         """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
         feature_map = output
diff --git a/tests/unit/methods/white_box/test_torch.py b/tests/unit/methods/white_box/test_torch.py
index 9aa9084a..e833fee3 100644
--- a/tests/unit/methods/white_box/test_torch.py
+++ b/tests/unit/methods/white_box/test_torch.py
@@ -35,7 +35,12 @@ class DummyCNN(torch.nn.Module):
     def __init__(self, num_classes: int = 2):
         super().__init__()
         self.num_classes = num_classes
-        self.feature = torch.nn.Identity()
+        self.feature = torch.nn.Sequential(
+            torch.nn.Identity(),
+            torch.nn.Identity(),
+            torch.nn.Identity(),
+            torch.nn.Identity(),
+        )
         self.neck = torch.nn.AdaptiveAvgPool2d((1, 1))
         self.output = torch.nn.LazyLinear(out_features=num_classes)

@@ -51,19 +56,33 @@ class DummyVIT(torch.nn.Module):
     def __init__(self, num_classes: int = 2):
         super().__init__()
         self.num_classes = num_classes
-        self.feature = torch.nn.Identity()
+        self.pre = torch.nn.Sequential(
+            torch.nn.Identity(),
+            torch.nn.Identity(),
+        )
+        self.feature = torch.nn.Sequential(
+            torch.nn.Identity(),
+            torch.nn.Identity(),
+            torch.nn.Identity(),
+            torch.nn.Identity(),
+        )
         self.output = torch.nn.LazyLinear(out_features=num_classes)
-        self.norm = None
+        self.norm1 = None

     def forward(self, x: torch.Tensor):
         b, c, h, w = x.shape
-        if not self.norm:
-            self.norm = torch.nn.LayerNorm(c)
+        if not self.norm1:
+            self.norm1 = torch.nn.LayerNorm(c)
+            self.norm2 = torch.nn.LayerNorm(c)
+            self.norm3 = torch.nn.LayerNorm(c)
+        x = self.pre(x)
         x = x.reshape(b, c, h * w)
         x = x.transpose(1, 2)
         x = torch.cat([torch.rand((b, 1, c)), x], dim=1)
         x = self.feature(x)
-        x = x + self.norm(x)
+        x = x + self.norm1(x)
+        x = x + self.norm2(x)
+        x = x + self.norm3(x)
         x = self.output(x[:, 0])
         return torch.nn.functional.softmax(x, dim=1)

@@ -136,6 +155,31 @@ def test_lazy_detect_feature_layer():
     assert not hasattr(method, "_detect_hook_handle")
     assert type(output) == dict
     assert method._feature_module is model_xai.feature
+    output = method.model_forward(data)
+    assert type(output) == dict  # still good for 2nd forward
+
+    model = DummyVIT()
+    method = TorchWhiteBoxMethod(model=model, target_layer=None)
+    model_xai = method.prepare_model()
+    assert hasattr(method, "_detect_hook_handle")
+    assert has_xai(model_xai)
+    data = np.random.rand(1, 3, 5, 5)
+    with pytest.raises(RuntimeError):
+        # 4D feature map search should fail for ViTs
+        output = method.model_forward(data)
+
+    model = DummyVIT()
+    method = TorchViTReciproCAM(model=model, target_layer=None)
+    model_xai = method.prepare_model()
+    assert hasattr(method, "_detect_hook_handle")
+    assert has_xai(model_xai)
+    data = np.random.rand(1, 3, 5, 5)
+    output = method.model_forward(data)
+    assert not hasattr(method, "_detect_hook_handle")
+    assert type(output) == dict
+    assert method._feature_modules[-3] is model_xai.norm1
+    output = method.model_forward(data)
+    assert type(output) == dict  # still good for 2nd forward


 def test_activationmap() -> None:
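
Note on the detection pattern (a reviewer's standalone sketch, not part of the patch): both hooks above rely on PyTorch's global forward hook, torch.nn.modules.module.register_module_forward_hook, which fires for every module during a single throwaway forward pass. The sketch below shows that pattern in isolation; SmallCNN, find_last_4d_module, and find_third_last_layernorm are hypothetical names for illustration only and do not exist in openvino_xai.

import torch


class SmallCNN(torch.nn.Module):
    """Hypothetical toy model, only to exercise the detection hooks."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(8, 16, 3, padding=1)
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.head = torch.nn.Linear(16, 2)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x).flatten(1)
        return self.head(x)


def find_last_4d_module(model, dummy_input):
    """Return the last module that emitted a 4D map with spatial size > 1x1."""
    found = []

    def _detect_hook(module, inputs, output):
        if isinstance(output, torch.Tensor) and output.dim() == 4 and output.shape[2] > 1 and output.shape[3] > 1:
            found.append(module)

    # Global hook: fires for *every* module during the forward pass below.
    handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
    try:
        model(dummy_input)
    finally:
        handle.remove()  # always detach the global hook, even if forward fails
    if not found:
        raise RuntimeError("Feature module with 4D output not found")
    return found[-1]  # deepest module that still carries a spatial feature map


def find_third_last_layernorm(model, dummy_input):
    """ViT variant: collect LayerNorms in call order and take the 3rd-last."""
    norms = []

    def _detect_hook(module, inputs, output):
        if isinstance(module, torch.nn.LayerNorm):
            norms.append(module)

    handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
    try:
        model(dummy_input)
    finally:
        handle.remove()
    if len(norms) < 3:
        raise RuntimeError("Fewer than 3 LayerNorm modules found")
    return norms[-3]


model = SmallCNN()
feature_module = find_last_4d_module(model, torch.rand(1, 3, 8, 8))
print(type(feature_module).__name__)  # -> Conv2d (conv2; pool output is 1x1, so it is excluded)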