From 8d6ebe38f49e08caade775dce2c099750c379ade Mon Sep 17 00:00:00 2001
From: Songki Choi
Date: Fri, 6 Sep 2024 14:04:22 +0900
Subject: [PATCH] Add basic 4D feature map layer detection

---
 openvino_xai/methods/white_box/torch.py    | 57 +++++++++++++++++-----
 tests/unit/methods/white_box/test_torch.py | 20 ++++++--
 2 files changed, 62 insertions(+), 15 deletions(-)

diff --git a/openvino_xai/methods/white_box/torch.py b/openvino_xai/methods/white_box/torch.py
index 29c72860..51b1f79a 100644
--- a/openvino_xai/methods/white_box/torch.py
+++ b/openvino_xai/methods/white_box/torch.py
@@ -55,11 +55,17 @@ def prepare_model(self, load_model: bool = True) -> torch.nn.Module:
             return self._model
 
         model = copy.deepcopy(self._model)
+
         # Feature
-        feature_module = self._find_feature_module(model, self._target_layer)
-        feature_module.register_forward_hook(self._feature_hook)
+        if self._target_layer:
+            feature_module = self._find_module_by_name(model, self._target_layer)
+            feature_module.register_forward_hook(self._feature_hook)
+        else:
+            self._detect_hook_handle = model.register_forward_pre_hook(self._lazy_detect_hook, prepend=True)
+
         # Output
         model.register_forward_hook(self._output_hook)
+
         setattr(model, "has_xai", True)
         model.eval()
@@ -86,9 +92,8 @@ def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping:
                 output[name] = data.numpy(force=True)
         return output
 
-    def _find_feature_module(self, model: torch.nn.Module, target_name: str | None):
-        if target_name is None:
-            raise ValueError("Target layer name should be specified")
+    def _find_module_by_name(self, model: torch.nn.Module, target_name: str) -> torch.nn.Module:
+        """Search for a module by substring match on its name."""
         target_module = None
         for name, module in model.named_modules():
             if target_name in name:
@@ -98,10 +103,37 @@ def _find_feature_module(self, model: torch.nn.Module, target_name: str | None):
         return target_module
 
     def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
+        """Store the feature map for saliency map generation."""
         self._feature_map = output
         return output
 
+    def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any:
+        """Detect the feature module during the first forward pass and register the feature hook."""
+        # Make sure this hook is called only once
+        if detect_hook_handle := getattr(self, "_detect_hook_handle", None):
+            detect_hook_handle.remove()
+            delattr(self, "_detect_hook_handle")
+
+        # Find the last layer that outputs a 4D tensor during a temporary forward pass
+        self._feature_module = None
+        def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
+            if isinstance(output, torch.Tensor):
+                shape = output.shape
+                if len(shape) == 4 and shape[2] > 1 and shape[3] > 1:
+                    self._feature_module = module
+        global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
+        try:
+            module.forward(*inputs)
+        finally:
+            global_hook_handle.remove()
+        if self._feature_module is None:
+            raise RuntimeError("Feature module with 4D output not found in the torch model")
+
+        # Register the feature hook on the detected module
+        self._feature_module.register_forward_hook(self._feature_hook)
+
     def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
+        """Return the model prediction with a placeholder saliency map."""
         return {
             "prediction": output,
             SALIENCY_MAP_OUTPUT_NAME: torch.empty_like(output),
@@ -164,16 +196,17 @@ def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
         return torch.cat(feature_maps)
 
     def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
-        batch_size, _, h, w = self._feature_shape
-        num_classes = output.shape[1]
-        predictions = output[:batch_size]
-        saliency_maps = output[batch_size:]
-        saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes])
-        saliency_maps = saliency_maps.transpose(1, 2)  # BxHWxC -> BxCxHW
+        """Split the combined B0xC output into BxC predictions and a BxCxHxW saliency map."""
+        batch_size, _, h, w = self._feature_shape  # B0xDxHxW
+        num_classes = output.shape[1]  # C
+        predictions = output[:batch_size]  # BxC
+        saliency_maps = output[batch_size:]  # BHWxC
+        saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes])  # BxHWxC
+        saliency_maps = saliency_maps.transpose(1, 2)  # BxCxHW
         if self._embed_scaling:
             saliency_maps = saliency_maps.reshape((batch_size * num_classes, h * w))
             saliency_maps = self._normalize_map(saliency_maps)
-        saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w])
+        saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w])  # BxCxHxW
         return {
             "prediction": predictions,
             SALIENCY_MAP_OUTPUT_NAME: saliency_maps,
diff --git a/tests/unit/methods/white_box/test_torch.py b/tests/unit/methods/white_box/test_torch.py
index 87bf4fae..9aa9084a 100644
--- a/tests/unit/methods/white_box/test_torch.py
+++ b/tests/unit/methods/white_box/test_torch.py
@@ -53,13 +53,17 @@ def __init__(self, num_classes: int = 2):
         self.num_classes = num_classes
         self.feature = torch.nn.Identity()
         self.output = torch.nn.LazyLinear(out_features=num_classes)
+        self.norm = None
 
     def forward(self, x: torch.Tensor):
         b, c, h, w = x.shape
+        if self.norm is None:
+            self.norm = torch.nn.LayerNorm(c)
         x = x.reshape(b, c, h * w)
         x = x.transpose(1, 2)
         x = torch.cat([torch.rand((b, 1, c)), x], dim=1)
         x = self.feature(x)
+        x = x + self.norm(x)
         x = self.output(x[:, 0])
         return torch.nn.functional.softmax(x, dim=1)
 
@@ -67,9 +71,6 @@ def forward(self, x: torch.Tensor):
 
 def test_torch_method():
     model = DummyCNN()
-    with pytest.raises(ValueError):
-        method = TorchWhiteBoxMethod(model=model, target_layer=None)
-        model_xai = method.prepare_model()
     with pytest.raises(ValueError):
         method = TorchWhiteBoxMethod(model=model, target_layer="something_else")
         model_xai = method.prepare_model()
@@ -124,6 +125,19 @@ def test_prepare_model():
     assert model_xai == model
 
 
+def test_lazy_detect_feature_layer():
+    model = DummyCNN()
+    method = TorchWhiteBoxMethod(model=model, target_layer=None)
+    model_xai = method.prepare_model()
+    assert hasattr(method, "_detect_hook_handle")
+    assert has_xai(model_xai)
+    data = np.random.rand(1, 3, 5, 5)
+    output = method.model_forward(data)
+    assert not hasattr(method, "_detect_hook_handle")
+    assert isinstance(output, dict)
+    assert method._feature_module is model_xai.feature
+
+
 def test_activationmap() -> None:
     batch_size = 2
     num_classes = 3
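
For illustration, a minimal usage sketch of the lazy detection path added above.
The TorchWhiteBoxMethod calls mirror test_lazy_detect_feature_layer; the TinyCNN
model, the input shape, and the final assert are illustrative assumptions, not
part of the patch. Note that with the base class the saliency entry is only a
placeholder (torch.empty_like, as the base _output_hook above shows); derived
methods compute the actual maps.

    import numpy as np
    import torch

    from openvino_xai.methods.white_box.torch import TorchWhiteBoxMethod


    class TinyCNN(torch.nn.Module):
        """Hypothetical stand-in model, not from the repo."""

        def __init__(self, num_classes: int = 2):
            super().__init__()
            self.feature = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)  # emits BxCxHxW
            self.head = torch.nn.Linear(8, num_classes)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.feature(x)     # last (and only) 4D activation in this model
            x = x.mean(dim=(2, 3))  # global average pool -> BxC
            return self.head(x)


    # target_layer=None triggers the new lazy path: a forward pre-hook replays
    # the first forward pass under a global module hook, remembers the last
    # module that emitted a 4D tensor with H > 1 and W > 1, and registers the
    # regular feature hook on that module.
    method = TorchWhiteBoxMethod(model=TinyCNN(), target_layer=None)
    model_xai = method.prepare_model()

    data = np.random.rand(1, 3, 16, 16).astype(np.float32)
    output = method.model_forward(data)  # detection happens on this first call
    assert method._feature_module is model_xai.feature  # the Conv2d was detected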