diff --git a/openvino_xai/methods/white_box/torch.py b/openvino_xai/methods/white_box/torch.py index 0bbe8f3c..801a1715 100644 --- a/openvino_xai/methods/white_box/torch.py +++ b/openvino_xai/methods/white_box/torch.py @@ -41,27 +41,33 @@ def __init__( target_layer: str | None = None, embed_scaling: bool = True, device_name: str = "CPU", + prepare_model: bool = True, **kwargs, ): super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name) self._target_layer = target_layer self._embed_scaling = embed_scaling + if prepare_model: + self.prepare_model() + def prepare_model(self, load_model: bool = True) -> torch.nn.Module: """Return XAI inserted model.""" if has_xai(self._model): if load_model: self._model_compiled = self._model return self._model + if self._model_compiled is not None: + return self._model_compiled model = copy.deepcopy(self._model) # Feature if self._target_layer: feature_module = self._find_module_by_name(model, self._target_layer) - feature_module.register_forward_hook(self._feature_hook) else: - self._detect_hook_handle = model.register_forward_pre_hook(self._lazy_detect_hook, prepend=True) + feature_module = self._find_feature_module(model) + feature_module.register_forward_hook(self._feature_hook) # Output model.register_forward_hook(self._output_hook) @@ -102,18 +108,8 @@ def _find_module_by_name(self, model: torch.nn.Module, target_name: str) -> torc raise ValueError(f"{target_name} not found in the torch model") return target_module - def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: - """Maipulate feature map for saliency map generation.""" - self._feature_map = output - return output - - def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any: - """Detect feature module in the first foward pass and register feature hook.""" - # Make sure this hook called only 1 time - if detect_hook_handle := getattr(self, "_detect_hook_handle", None): - detect_hook_handle.remove() - delattr(self, "_detect_hook_handle") - + def _find_feature_module(self, module: torch.nn.Module) -> torch.nn.Module: + """Detect feature module in the model.""" # Find the last layer that outputs 4D tensor during temp forward pass self._feature_module = None self._num_modules = 0 @@ -128,7 +124,7 @@ def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None: global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook) try: - module.forward(*inputs) + module.forward(torch.zeros((1, 3, 128, 128))) finally: global_hook_handle.remove() if self._feature_module is None: @@ -138,8 +134,12 @@ def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None: f"Modules with 4D output end in early-half stages: {100 * self._feature_module.index / self._num_modules}%" ) - # Set feature hook - self._feature_module.register_forward_hook(self._feature_hook) + return self._feature_module + + def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: + """Maipulate feature map for saliency map generation.""" + self._feature_map = output + return output def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]: """Split combined output B0xC into BxC precition and BxCxHxW saliency map.""" @@ -189,8 +189,8 @@ class TorchReciproCAM(TorchWhiteBoxMethod): """ def __init__(self, *args, optimize_gap: bool = False, **kwargs): - super().__init__(*args, **kwargs) self._optimize_gap = optimize_gap + super().__init__(*args, **kwargs) def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps.""" @@ -262,34 +262,24 @@ def __init__( normalize: bool = True, **kwargs, ) -> None: - super().__init__(*args, **kwargs) self._use_gaussian = use_gaussian self._use_cls_token = use_cls_token + super().__init__(*args, **kwargs) - def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any: - """Detect feature module in the first foward pass and register feature hook.""" - # Make sure this hook called only 1 time - if detect_hook_handle := getattr(self, "_detect_hook_handle", None): - detect_hook_handle.remove() - delattr(self, "_detect_hook_handle") - - # Find the 3rd last LayerNorm module during temp forward pass - self._feature_modules: list[torch.nn.Module] = [] - - def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None: - if isinstance(module, torch.nn.LayerNorm): - self._feature_modules.append(module) + def _find_feature_module(self, module: torch.nn.Module) -> torch.nn.Module: + """Detect feature module in the model.""" + # Find the 3rd last LayerNorm module + self._feature_module = None + feature_modules: list[torch.nn.Module] = [] + for _, submodule in module.named_modules(): + if isinstance(submodule, torch.nn.LayerNorm): + feature_modules.append(submodule) - global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook) - try: - module.forward(*inputs) - finally: - global_hook_handle.remove() - if len(self._feature_modules) < 3: + if len(feature_modules) < 3: raise RuntimeError("Feature modules with LayerNorm is less than 3 in the torch model") - # Set feature hook - self._feature_modules[-3].register_forward_hook(self._feature_hook) + self._feature_module = feature_modules[-3] + return self._feature_module def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor: """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps.""" diff --git a/tests/intg/test_classification_timm.py b/tests/intg/test_classification_timm.py index d7667d61..5ce553b6 100644 --- a/tests/intg/test_classification_timm.py +++ b/tests/intg/test_classification_timm.py @@ -454,7 +454,8 @@ def test_model_format(self, model_id, explain_mode, model_format): "deit_tiny_patch16_224.fb_in1k", ], ) - def test_torch_insert_xai_with_layer(self, model_id: str): + @pytest.mark.parametrize("detect", ["auto", "name"]) + def test_torch_insert_xai_with_layer(self, model_id: str, detect: str): xai_cfg = { "resnet18.a1_in1k": ("layer4", Method.RECIPROCAM), "efficientnet_b0.ra_in1k": ("bn2", Method.RECIPROCAM), @@ -465,6 +466,9 @@ def test_torch_insert_xai_with_layer(self, model_id: str): model_dir = self.data_dir / "timm_models" / "converted_models" model, model_cfg = self.get_timm_model(model_id, model_dir) + target_layer = xai_cfg[model_id][0] if detect == "name" else None + explain_method = xai_cfg[model_id][1] + image = cv2.imread("tests/assets/cheetah_person.jpg") image = cv2.resize(image, dsize=model_cfg["input_size"][1:]) image = cv2.cvtColor(image, code=cv2.COLOR_BGR2RGB) @@ -478,8 +482,8 @@ def test_torch_insert_xai_with_layer(self, model_id: str): xai_model: torch.nn.Module = insert_xai( model, task=Task.CLASSIFICATION, - target_layer=xai_cfg[model_id][0], - explain_method=xai_cfg[model_id][1], + target_layer=target_layer, + explain_method=explain_method, ) with torch.no_grad(): diff --git a/tests/unit/methods/white_box/test_torch.py b/tests/unit/methods/white_box/test_torch.py index e833fee3..8bea7d4d 100644 --- a/tests/unit/methods/white_box/test_torch.py +++ b/tests/unit/methods/white_box/test_torch.py @@ -53,9 +53,10 @@ def forward(self, x: torch.Tensor): class DummyVIT(torch.nn.Module): - def __init__(self, num_classes: int = 2): + def __init__(self, num_classes: int = 2, dim: int = 3): super().__init__() self.num_classes = num_classes + self.dim = dim self.pre = torch.nn.Sequential( torch.nn.Identity(), torch.nn.Identity(), @@ -66,15 +67,13 @@ def __init__(self, num_classes: int = 2): torch.nn.Identity(), torch.nn.Identity(), ) + self.norm1 = torch.nn.LayerNorm(dim) + self.norm2 = torch.nn.LayerNorm(dim) + self.norm3 = torch.nn.LayerNorm(dim) self.output = torch.nn.LazyLinear(out_features=num_classes) - self.norm1 = None def forward(self, x: torch.Tensor): b, c, h, w = x.shape - if not self.norm1: - self.norm1 = torch.nn.LayerNorm(c) - self.norm2 = torch.nn.LayerNorm(c) - self.norm3 = torch.nn.LayerNorm(c) x = self.pre(x) x = x.reshape(b, c, h * w) x = x.transpose(1, 2) @@ -129,7 +128,9 @@ def _output_hook( def test_prepare_model(): model = DummyCNN() - method = TorchWhiteBoxMethod(model=model, target_layer="feature") + method = TorchWhiteBoxMethod(model=model, target_layer="feature", prepare_model=False) + model_xai = method.prepare_model(load_model=False) + assert method._model_compiled is None model_xai = method.prepare_model(load_model=False) assert method._model_compiled is None assert model is not model_xai @@ -144,40 +145,31 @@ def test_prepare_model(): assert model_xai == model -def test_lazy_detect_feature_layer(): +def test_detect_feature_layer(): model = DummyCNN() method = TorchWhiteBoxMethod(model=model, target_layer=None) model_xai = method.prepare_model() - assert hasattr(method, "_detect_hook_handle") assert has_xai(model_xai) data = np.random.rand(1, 3, 5, 5) output = method.model_forward(data) - assert not hasattr(method, "_detect_hook_handle") assert type(output) == dict assert method._feature_module is model_xai.feature output = method.model_forward(data) assert type(output) == dict # still good for 2nd forward model = DummyVIT() - method = TorchWhiteBoxMethod(model=model, target_layer=None) - model_xai = method.prepare_model() - assert hasattr(method, "_detect_hook_handle") - assert has_xai(model_xai) - data = np.random.rand(1, 3, 5, 5) with pytest.raises(RuntimeError): # 4D feature map search should fail for ViTs - output = method.model_forward(data) + method = TorchWhiteBoxMethod(model=model, target_layer=None) model = DummyVIT() method = TorchViTReciproCAM(model=model, target_layer=None) model_xai = method.prepare_model() - assert hasattr(method, "_detect_hook_handle") assert has_xai(model_xai) data = np.random.rand(1, 3, 5, 5) output = method.model_forward(data) - assert not hasattr(method, "_detect_hook_handle") assert type(output) == dict - assert method._feature_modules[-3] is model_xai.norm1 + assert method._feature_module is model_xai.norm1 output = method.model_forward(data) assert type(output) == dict # still good for 2nd forward @@ -222,13 +214,14 @@ def test_reciprocam(optimize_gap: bool) -> None: def test_vitreciprocam(use_gaussian: bool, use_cls_token: bool) -> None: batch_size = 2 num_classes = 3 - model = DummyVIT(num_classes=num_classes) + dim = 3 + model = DummyVIT(num_classes=num_classes, dim=dim) method = TorchViTReciproCAM( model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token ) model_xai = method.prepare_model() assert has_xai(model_xai) - data = np.random.rand(batch_size, 4, 5, 5) + data = np.random.rand(batch_size, dim, 5, 5) output = method.model_forward(data) assert type(output) == dict saliency_maps = output[SALIENCY_MAP_OUTPUT_NAME]