Add N-last LayerNorm module detection for ViTs
goodsong81 committed Sep 6, 2024
1 parent 8d6ebe3 commit 0603eea
Showing 2 changed files with 85 additions and 7 deletions.
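Before the diffs, a note on the mechanism: both detection paths in this commit run a throwaway forward pass under a global forward hook (torch.nn.modules.module.register_module_forward_hook) and record the modules of interest in execution order. A minimal standalone sketch of that pattern — find_layernorms is a hypothetical helper name, not code from this commit:

import torch

def find_layernorms(model: torch.nn.Module, dummy_input: torch.Tensor) -> list:
    """Record every LayerNorm fired during one forward pass, in execution order."""
    found: list = []

    def _detect_hook(module, inputs, output):
        if isinstance(module, torch.nn.LayerNorm):
            found.append(module)

    # Global hook: fires for every module's forward; remove it right after the pass.
    handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
    try:
        model(dummy_input)
    finally:
        handle.remove()
    return found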
36 changes: 35 additions & 1 deletion openvino_xai/methods/white_box/torch.py
@@ -116,18 +116,27 @@ def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any:

# Find the last layer that outputs a 4D tensor during a temp forward pass
self._feature_module = None
self._num_modules = 0

def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
if isinstance(output, torch.Tensor):
module.index = self._num_modules
self._num_modules += 1
shape = output.shape
if len(shape) == 4 and shape[2] > 1 and shape[3] > 1:
self._feature_module = module

global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
try:
module.forward(*inputs)
finally:
global_hook_handle.remove()
if self._feature_module is None:
raise RuntimeError(f"Feature module with 4D output not found in the torch model")
raise RuntimeError("Feature module with 4D output not found in the torch model")

Codecov warning: added line openvino_xai/methods/white_box/torch.py#L135 was not covered by tests.
if self._feature_module.index / self._num_modules < 0.5:  # Heuristic: 4D outputs only in the early half suggest a ViT-like architecture
raise RuntimeError(
f"The last module with 4D output sits in the early half of the model, "
f"at {100 * self._feature_module.index / self._num_modules:.1f}% of module depth"
)

# Set feature hook
self._feature_module.register_forward_hook(self._feature_hook)
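As an aside on the heuristic above: CNN feature maps are 4D with spatial extent, while ViT token tensors are 3D, so the shape check never fires for a pure ViT — exactly the failure mode the new RuntimeError reports. An illustration with example shapes (the sizes are arbitrary):

import torch

cnn_feat = torch.rand(1, 256, 7, 7)  # CNN feature map (B, C, H, W)
vit_feat = torch.rand(1, 50, 768)    # ViT token tensor (B, N, C)
assert len(cnn_feat.shape) == 4 and cnn_feat.shape[2] > 1 and cnn_feat.shape[3] > 1  # matched
assert len(vit_feat.shape) != 4  # never matched, so _feature_module stays None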
@@ -257,6 +266,31 @@ def __init__(
self._use_gaussian = use_gaussian
self._use_cls_token = use_cls_token

def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any:
"""Detect feature module in the first foward pass and register feature hook."""
# Make sure this hook called only 1 time
if detect_hook_handle := getattr(self, "_detect_hook_handle", None):
detect_hook_handle.remove()
delattr(self, "_detect_hook_handle")

# Find the third-to-last LayerNorm module during a temp forward pass
self._feature_modules: list[torch.nn.Module] = []

def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
if isinstance(module, torch.nn.LayerNorm):
self._feature_modules.append(module)

global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
try:
module.forward(*inputs)
finally:
global_hook_handle.remove()
if len(self._feature_modules) < 3:
raise RuntimeError("Feature modules with LayerNorm is less than 3 in the torch model")

Codecov warning: added line openvino_xai/methods/white_box/torch.py#L289 was not covered by tests.

# Set feature hook
self._feature_modules[-3].register_forward_hook(self._feature_hook)

def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
"""feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
feature_map = output
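A note on the [-3] index used above: in a standard pre-norm ViT the LayerNorms fire in the order block1.norm1, block1.norm2, …, blockN.norm1, blockN.norm2, final_norm, so the third-from-last is normally the input norm of the last block. That layout is an assumption about the model, not something the code verifies. A toy check, reusing the find_layernorms sketch from above (ToyBlock/ToyViT are illustrative):

import torch

class ToyBlock(torch.nn.Module):
    """Pre-norm transformer block skeleton (attention/MLP omitted)."""
    def __init__(self, dim: int):
        super().__init__()
        self.norm1 = torch.nn.LayerNorm(dim)  # pre-attention norm
        self.norm2 = torch.nn.LayerNorm(dim)  # pre-MLP norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm2(self.norm1(x))

class ToyViT(torch.nn.Module):
    def __init__(self, dim: int = 8, depth: int = 2):
        super().__init__()
        self.blocks = torch.nn.ModuleList(ToyBlock(dim) for _ in range(depth))
        self.final_norm = torch.nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return self.final_norm(x)

model = ToyViT()
norms = find_layernorms(model, torch.rand(1, 4, 8))
assert norms[-3] is model.blocks[-1].norm1  # the last block's input norm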
56 changes: 50 additions & 6 deletions tests/unit/methods/white_box/test_torch.py
@@ -35,7 +35,12 @@ class DummyCNN(torch.nn.Module):
def __init__(self, num_classes: int = 2):
super().__init__()
self.num_classes = num_classes
- self.feature = torch.nn.Identity()
+ self.feature = torch.nn.Sequential(
+ torch.nn.Identity(),
+ torch.nn.Identity(),
+ torch.nn.Identity(),
+ torch.nn.Identity(),
+ )
self.neck = torch.nn.AdaptiveAvgPool2d((1, 1))
self.output = torch.nn.LazyLinear(out_features=num_classes)

@@ -51,19 +56,33 @@ class DummyVIT(torch.nn.Module):
def __init__(self, num_classes: int = 2):
super().__init__()
self.num_classes = num_classes
- self.feature = torch.nn.Identity()
+ self.pre = torch.nn.Sequential(
+ torch.nn.Identity(),
+ torch.nn.Identity(),
+ )
+ self.feature = torch.nn.Sequential(
+ torch.nn.Identity(),
+ torch.nn.Identity(),
+ torch.nn.Identity(),
+ torch.nn.Identity(),
+ )
self.output = torch.nn.LazyLinear(out_features=num_classes)
- self.norm = None
+ self.norm1 = None

def forward(self, x: torch.Tensor):
b, c, h, w = x.shape
- if not self.norm:
- self.norm = torch.nn.LayerNorm(c)
+ if not self.norm1:
+ self.norm1 = torch.nn.LayerNorm(c)
+ self.norm2 = torch.nn.LayerNorm(c)
+ self.norm3 = torch.nn.LayerNorm(c)
x = self.pre(x)
x = x.reshape(b, c, h * w)
x = x.transpose(1, 2)
x = torch.cat([torch.rand((b, 1, c)), x], dim=1)
x = self.feature(x)
- x = x + self.norm(x)
+ x = x + self.norm1(x)
+ x = x + self.norm2(x)
+ x = x + self.norm3(x)
x = self.output(x[:, 0])
return torch.nn.functional.softmax(x, dim=1)

@@ -136,6 +155,31 @@ def test_lazy_detect_feature_layer():
assert not hasattr(method, "_detect_hook_handle")
assert type(output) == dict
assert method._feature_module is model_xai.feature
output = method.model_forward(data)
assert type(output) == dict # still good for 2nd forward

model = DummyVIT()
method = TorchWhiteBoxMethod(model=model, target_layer=None)
model_xai = method.prepare_model()
assert hasattr(method, "_detect_hook_handle")
assert has_xai(model_xai)
data = np.random.rand(1, 3, 5, 5)
with pytest.raises(RuntimeError):
# 4D feature map search should fail for ViTs
output = method.model_forward(data)

model = DummyVIT()
method = TorchViTReciproCAM(model=model, target_layer=None)
model_xai = method.prepare_model()
assert hasattr(method, "_detect_hook_handle")
assert has_xai(model_xai)
data = np.random.rand(1, 3, 5, 5)
output = method.model_forward(data)
assert not hasattr(method, "_detect_hook_handle")
assert type(output) == dict
assert method._feature_modules[-3] is model_xai.norm1
output = method.model_forward(data)
assert type(output) == dict # still good for 2nd forward


def test_activationmap() -> None:
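For reference, the caller-side flow that the new ViT branch of test_lazy_detect_feature_layer exercises, condensed below (DummyVIT and TorchViTReciproCAM as imported by this test module; not new API):

import numpy as np

model = DummyVIT()
method = TorchViTReciproCAM(model=model, target_layer=None)
model_xai = method.prepare_model()      # registers the one-shot detect pre-hook
data = np.random.rand(1, 3, 5, 5)
output = method.model_forward(data)     # 1st pass: detects norm1, removes the hook
output = method.model_forward(data)     # later passes skip detection entirely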
