Add basic 4D feature map layer detection
goodsong81 committed Sep 6, 2024
1 parent 625c71c commit 8d6ebe3
Showing 2 changed files with 62 additions and 15 deletions.
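With this change, TorchWhiteBoxMethod no longer requires a target_layer name: when it is None, a one-shot pre-hook finds the last module that produces a 4D (BxCxHxW) feature map during the first forward pass and attaches the feature hook there. A minimal usage sketch, mirroring the new unit test below; the tiny Sequential model, input shape, and variable names are illustrative, and only TorchWhiteBoxMethod, prepare_model, and model_forward come from the changed files:

import numpy as np
import torch

from openvino_xai.methods.white_box.torch import TorchWhiteBoxMethod

# Illustrative model: the Conv2d output (Bx8xHxW) is the last 4D feature map,
# so it should be picked up by the new lazy detection.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(8, 2),
)

method = TorchWhiteBoxMethod(model=model, target_layer=None)  # previously this raised ValueError
model_xai = method.prepare_model()  # registers the one-shot detection pre-hook
data = np.random.rand(1, 3, 32, 32).astype(np.float32)
output = method.model_forward(data)  # first pass detects the feature layer, then produces outputs
# output is a dict holding "prediction" and the saliency-map entry (SALIENCY_MAP_OUTPUT_NAME)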
57 changes: 45 additions & 12 deletions openvino_xai/methods/white_box/torch.py
@@ -55,11 +55,17 @@ def prepare_model(self, load_model: bool = True) -> torch.nn.Module:
return self._model

model = copy.deepcopy(self._model)

# Feature
feature_module = self._find_feature_module(model, self._target_layer)
feature_module.register_forward_hook(self._feature_hook)
if self._target_layer:
feature_module = self._find_module_by_name(model, self._target_layer)
feature_module.register_forward_hook(self._feature_hook)
else:
self._detect_hook_handle = model.register_forward_pre_hook(self._lazy_detect_hook, prepend=True)

# Output
model.register_forward_hook(self._output_hook)

setattr(model, "has_xai", True)
model.eval()

@@ -86,9 +92,8 @@ def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping:
output[name] = data.numpy(force=True)
return output

def _find_feature_module(self, model: torch.nn.Module, target_name: str | None):
if target_name is None:
raise ValueError("Target layer name should be specified")
def _find_module_by_name(self, model: torch.nn.Module, target_name: str) -> torch.nn.Module:
"""Search layer by name sub string match."""
target_module = None
for name, module in model.named_modules():
if target_name in name:
@@ -98,10 +103,37 @@ def _find_feature_module(self, model: torch.nn.Module, target_name: str | None):
return target_module

def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
"""Maipulate feature map for saliency map generation."""
self._feature_map = output
return output

def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any:
"""Detect feature module in the first foward pass and register feature hook."""
# Make sure this hook called only 1 time
if detect_hook_handle := getattr(self, "_detect_hook_handle", None):
detect_hook_handle.remove()
delattr(self, "_detect_hook_handle")

# Find the last layer that outputs a 4D tensor during a temporary forward pass
self._feature_module = None
def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
if isinstance(output, torch.Tensor):
shape = output.shape
if len(shape) == 4 and shape[2] > 1 and shape[3] > 1:
self._feature_module = module
global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
try:
module.forward(*inputs)
finally:
global_hook_handle.remove()
if self._feature_module is None:
raise RuntimeError("Feature module with 4D output not found in the torch model")

# Set feature hook
self._feature_module.register_forward_hook(self._feature_hook)

def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Split combined output B0xC into BxC precition and BxCxHxW saliency map."""
return {
"prediction": output,
SALIENCY_MAP_OUTPUT_NAME: torch.empty_like(output),
@@ -164,16 +196,17 @@ def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
return torch.cat(feature_maps)

def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
batch_size, _, h, w = self._feature_shape
num_classes = output.shape[1]
predictions = output[:batch_size]
saliency_maps = output[batch_size:]
saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes])
saliency_maps = saliency_maps.transpose(1, 2) # BxHWxC -> BxCxHW
"""Split combined output B0xC into BxC precition and BxCxHxW saliency map."""
batch_size, _, h, w = self._feature_shape # B0xDxHxW
num_classes = output.shape[1] # C
predictions = output[:batch_size] # BxC
saliency_maps = output[batch_size:] # BHWxC
saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes]) # BxHWxC
saliency_maps = saliency_maps.transpose(1, 2) # BxCxHW
if self._embed_scaling:
saliency_maps = saliency_maps.reshape((batch_size * num_classes, h * w))
saliency_maps = self._normalize_map(saliency_maps)
saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w])
saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w]) # BxCxHxW
return {
"prediction": predictions,
SALIENCY_MAP_OUTPUT_NAME: saliency_maps,
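The detection added above boils down to the following self-contained sketch: a temporary global forward hook (torch.nn.modules.module.register_module_forward_hook) observes every module during a single forward pass and remembers the last one whose output is a 4D tensor with spatial size larger than 1x1. The helper name find_last_4d_module is illustrative, not part of the library:

import torch

def find_last_4d_module(model: torch.nn.Module, dummy_input: torch.Tensor) -> torch.nn.Module:
    """Return the last module whose output is a BxCxHxW tensor with H > 1 and W > 1."""
    hits: list[torch.nn.Module] = []

    def _detect(module: torch.nn.Module, inputs, output) -> None:
        if isinstance(output, torch.Tensor) and output.dim() == 4 and output.shape[2] > 1 and output.shape[3] > 1:
            hits.append(module)  # the last appended module wins

    handle = torch.nn.modules.module.register_module_forward_hook(_detect)
    try:
        model(dummy_input)
    finally:
        handle.remove()
    if not hits:
        raise RuntimeError("No module with a 4D feature-map output was found")
    return hits[-1]

# Example: for a small CNN this returns the last convolution before global pooling.
cnn = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, padding=1),
    torch.nn.Conv2d(8, 16, 3, padding=1),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(16, 2),
)
print(find_last_4d_module(cnn, torch.rand(1, 3, 16, 16)))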
20 changes: 17 additions & 3 deletions tests/unit/methods/white_box/test_torch.py
@@ -53,23 +53,24 @@ def __init__(self, num_classes: int = 2):
self.num_classes = num_classes
self.feature = torch.nn.Identity()
self.output = torch.nn.LazyLinear(out_features=num_classes)
self.norm = None

def forward(self, x: torch.Tensor):
b, c, h, w = x.shape
if not self.norm:
self.norm = torch.nn.LayerNorm(c)
x = x.reshape(b, c, h * w)
x = x.transpose(1, 2)
x = torch.cat([torch.rand((b, 1, c)), x], dim=1)
x = self.feature(x)
x = x + self.norm(x)
x = self.output(x[:, 0])
return torch.nn.functional.softmax(x, dim=1)


def test_torch_method():
model = DummyCNN()

with pytest.raises(ValueError):
method = TorchWhiteBoxMethod(model=model, target_layer=None)
model_xai = method.prepare_model()
with pytest.raises(ValueError):
method = TorchWhiteBoxMethod(model=model, target_layer="something_else")
model_xai = method.prepare_model()
@@ -124,6 +125,19 @@ def test_prepare_model():
assert model_xai == model


def test_lazy_detect_feature_layer():
model = DummyCNN()
method = TorchWhiteBoxMethod(model=model, target_layer=None)
model_xai = method.prepare_model()
assert hasattr(method, "_detect_hook_handle")
assert has_xai(model_xai)
data = np.random.rand(1, 3, 5, 5)
output = method.model_forward(data)
assert not hasattr(method, "_detect_hook_handle")
assert type(output) == dict
assert method._feature_module is model_xai.feature


def test_activationmap() -> None:
batch_size = 2
num_classes = 3
