Add integration test for auto layer detection
goodsong81 committed Sep 9, 2024
1 parent 0603eea commit bac4363
Showing 3 changed files with 51 additions and 64 deletions.
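
Note: the auto detection exercised by the new test is reached by omitting target_layer in the public insert_xai API. Below is a minimal sketch of the two modes the test parametrizes ("auto" vs. "name"); the import paths and the timm setup are assumptions inferred from the test files in this commit:

import timm
import torch

# Assumed import paths, inferred from the test files in this commit.
from openvino_xai import insert_xai
from openvino_xai.common.parameters import Method, Task

model = timm.create_model("resnet18.a1_in1k", pretrained=True).eval()

# detect == "auto": no target_layer; the feature layer is detected automatically.
xai_model = insert_xai(
    model,
    task=Task.CLASSIFICATION,
    target_layer=None,
    explain_method=Method.RECIPROCAM,
)

# detect == "name": the pre-existing path, naming the feature layer explicitly.
xai_model = insert_xai(
    model,
    task=Task.CLASSIFICATION,
    target_layer="layer4",
    explain_method=Method.RECIPROCAM,
)

with torch.no_grad():
    outputs = xai_model(torch.zeros(1, 3, 224, 224))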
70 changes: 30 additions & 40 deletions openvino_xai/methods/white_box/torch.py
@@ -41,27 +41,33 @@ def __init__(
target_layer: str | None = None,
embed_scaling: bool = True,
device_name: str = "CPU",
prepare_model: bool = True,
**kwargs,
):
super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name)
self._target_layer = target_layer
self._embed_scaling = embed_scaling

if prepare_model:
self.prepare_model()

def prepare_model(self, load_model: bool = True) -> torch.nn.Module:
"""Return XAI inserted model."""
if has_xai(self._model):
if load_model:
self._model_compiled = self._model
return self._model
if self._model_compiled is not None:
return self._model_compiled

model = copy.deepcopy(self._model)

# Feature
if self._target_layer:
feature_module = self._find_module_by_name(model, self._target_layer)
feature_module.register_forward_hook(self._feature_hook)
else:
self._detect_hook_handle = model.register_forward_pre_hook(self._lazy_detect_hook, prepend=True)
feature_module = self._find_feature_module(model)
feature_module.register_forward_hook(self._feature_hook)

# Output
model.register_forward_hook(self._output_hook)
@@ -102,18 +108,8 @@ def _find_module_by_name(self, model: torch.nn.Module, target_name: str) -> torch.nn.Module:
raise ValueError(f"{target_name} not found in the torch model")
return target_module

def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
"""Maipulate feature map for saliency map generation."""
self._feature_map = output
return output

def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any:
"""Detect feature module in the first foward pass and register feature hook."""
# Make sure this hook called only 1 time
if detect_hook_handle := getattr(self, "_detect_hook_handle", None):
detect_hook_handle.remove()
delattr(self, "_detect_hook_handle")

def _find_feature_module(self, module: torch.nn.Module) -> torch.nn.Module:
"""Detect feature module in the model."""
# Find the last layer that outputs 4D tensor during temp forward pass
self._feature_module = None
self._num_modules = 0
@@ -128,7 +124,7 @@ def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:

global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
try:
module.forward(*inputs)
module.forward(torch.zeros((1, 3, 128, 128)))
finally:
global_hook_handle.remove()
if self._feature_module is None:
@@ -138,8 +134,12 @@ def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
f"Modules with 4D output end in early-half stages: {100 * self._feature_module.index / self._num_modules}%"
)

# Set feature hook
self._feature_module.register_forward_hook(self._feature_hook)
return self._feature_module

def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
"""Maipulate feature map for saliency map generation."""
self._feature_map = output
return output

def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Split combined output B0xC into BxC precition and BxCxHxW saliency map."""
@@ -189,8 +189,8 @@ class TorchReciproCAM(TorchWhiteBoxMethod):
"""

def __init__(self, *args, optimize_gap: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self._optimize_gap = optimize_gap
super().__init__(*args, **kwargs)

def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
"""feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
@@ -262,34 +262,24 @@ def __init__(
normalize: bool = True,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self._use_gaussian = use_gaussian
self._use_cls_token = use_cls_token
super().__init__(*args, **kwargs)

def _lazy_detect_hook(self, module: torch.nn.Module, inputs: Any) -> Any:
"""Detect feature module in the first foward pass and register feature hook."""
# Make sure this hook called only 1 time
if detect_hook_handle := getattr(self, "_detect_hook_handle", None):
detect_hook_handle.remove()
delattr(self, "_detect_hook_handle")

# Find the 3rd last LayerNorm module during temp forward pass
self._feature_modules: list[torch.nn.Module] = []

def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
if isinstance(module, torch.nn.LayerNorm):
self._feature_modules.append(module)
def _find_feature_module(self, module: torch.nn.Module) -> torch.nn.Module:
"""Detect feature module in the model."""
# Find the 3rd last LayerNorm module
self._feature_module = None
feature_modules: list[torch.nn.Module] = []
for _, submodule in module.named_modules():
if isinstance(submodule, torch.nn.LayerNorm):
feature_modules.append(submodule)

global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
try:
module.forward(*inputs)
finally:
global_hook_handle.remove()
if len(self._feature_modules) < 3:
if len(feature_modules) < 3:
raise RuntimeError("Feature modules with LayerNorm is less than 3 in the torch model")

# Set feature hook
self._feature_modules[-3].register_forward_hook(self._feature_hook)
self._feature_module = feature_modules[-3]
return self._feature_module

def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
"""feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
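Note: the two detection heuristics introduced above can be summarized standalone. A minimal sketch with hypothetical helper names; the global-hook API and the dummy-input shape mirror the diff:

import torch

def find_last_4d_module(model: torch.nn.Module) -> torch.nn.Module:
    # CNN heuristic: remember the last module that emits a 4D feature map
    # during a throwaway forward pass (cf. _find_feature_module above).
    found: list[torch.nn.Module] = []

    def _detect(module, inputs, output):
        if isinstance(output, torch.Tensor) and output.dim() == 4:
            found.append(module)

    handle = torch.nn.modules.module.register_module_forward_hook(_detect)
    try:
        model(torch.zeros((1, 3, 128, 128)))  # same dummy input as the diff
    finally:
        handle.remove()
    if not found:
        raise RuntimeError("No module with a 4D output found")
    return found[-1]

def find_third_last_layernorm(model: torch.nn.Module) -> torch.nn.Module:
    # ViT heuristic: static module enumeration, no forward pass required
    # (cf. TorchViTReciproCAM._find_feature_module above).
    norms = [m for m in model.modules() if isinstance(m, torch.nn.LayerNorm)]
    if len(norms) < 3:
        raise RuntimeError("Fewer than 3 LayerNorm modules in the model")
    return norms[-3]
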
10 changes: 7 additions & 3 deletions tests/intg/test_classification_timm.py
@@ -454,7 +454,8 @@ def test_model_format(self, model_id, explain_mode, model_format):
"deit_tiny_patch16_224.fb_in1k",
],
)
def test_torch_insert_xai_with_layer(self, model_id: str):
@pytest.mark.parametrize("detect", ["auto", "name"])
def test_torch_insert_xai_with_layer(self, model_id: str, detect: str):
xai_cfg = {
"resnet18.a1_in1k": ("layer4", Method.RECIPROCAM),
"efficientnet_b0.ra_in1k": ("bn2", Method.RECIPROCAM),
@@ -465,6 +466,9 @@ def test_torch_insert_xai_with_layer(self, model_id: str):
model_dir = self.data_dir / "timm_models" / "converted_models"
model, model_cfg = self.get_timm_model(model_id, model_dir)

target_layer = xai_cfg[model_id][0] if detect == "name" else None
explain_method = xai_cfg[model_id][1]

image = cv2.imread("tests/assets/cheetah_person.jpg")
image = cv2.resize(image, dsize=model_cfg["input_size"][1:])
image = cv2.cvtColor(image, code=cv2.COLOR_BGR2RGB)
@@ -478,8 +482,8 @@ def test_torch_insert_xai_with_layer(self, model_id: str):
xai_model: torch.nn.Module = insert_xai(
model,
task=Task.CLASSIFICATION,
target_layer=xai_cfg[model_id][0],
explain_method=xai_cfg[model_id][1],
target_layer=target_layer,
explain_method=explain_method,
)

with torch.no_grad():
35 changes: 14 additions & 21 deletions tests/unit/methods/white_box/test_torch.py
@@ -53,9 +53,10 @@ def forward(self, x: torch.Tensor):


class DummyVIT(torch.nn.Module):
def __init__(self, num_classes: int = 2):
def __init__(self, num_classes: int = 2, dim: int = 3):
super().__init__()
self.num_classes = num_classes
self.dim = dim
self.pre = torch.nn.Sequential(
torch.nn.Identity(),
torch.nn.Identity(),
@@ -66,15 +67,13 @@ def __init__(self, num_classes: int = 2):
torch.nn.Identity(),
torch.nn.Identity(),
)
self.norm1 = torch.nn.LayerNorm(dim)
self.norm2 = torch.nn.LayerNorm(dim)
self.norm3 = torch.nn.LayerNorm(dim)
self.output = torch.nn.LazyLinear(out_features=num_classes)
self.norm1 = None

def forward(self, x: torch.Tensor):
b, c, h, w = x.shape
if not self.norm1:
self.norm1 = torch.nn.LayerNorm(c)
self.norm2 = torch.nn.LayerNorm(c)
self.norm3 = torch.nn.LayerNorm(c)
x = self.pre(x)
x = x.reshape(b, c, h * w)
x = x.transpose(1, 2)
@@ -129,7 +128,9 @@ def _output_hook(

def test_prepare_model():
model = DummyCNN()
method = TorchWhiteBoxMethod(model=model, target_layer="feature")
method = TorchWhiteBoxMethod(model=model, target_layer="feature", prepare_model=False)
model_xai = method.prepare_model(load_model=False)
assert method._model_compiled is None
model_xai = method.prepare_model(load_model=False)
assert method._model_compiled is None
assert model is not model_xai
@@ -144,40 +145,31 @@ def test_prepare_model():
assert model_xai == model


def test_lazy_detect_feature_layer():
def test_detect_feature_layer():
model = DummyCNN()
method = TorchWhiteBoxMethod(model=model, target_layer=None)
model_xai = method.prepare_model()
assert hasattr(method, "_detect_hook_handle")
assert has_xai(model_xai)
data = np.random.rand(1, 3, 5, 5)
output = method.model_forward(data)
assert not hasattr(method, "_detect_hook_handle")
assert type(output) == dict
assert method._feature_module is model_xai.feature
output = method.model_forward(data)
assert type(output) == dict # still good for 2nd forward

model = DummyVIT()
method = TorchWhiteBoxMethod(model=model, target_layer=None)
model_xai = method.prepare_model()
assert hasattr(method, "_detect_hook_handle")
assert has_xai(model_xai)
data = np.random.rand(1, 3, 5, 5)
with pytest.raises(RuntimeError):
# 4D feature map search should fail for ViTs
output = method.model_forward(data)
method = TorchWhiteBoxMethod(model=model, target_layer=None)

model = DummyVIT()
method = TorchViTReciproCAM(model=model, target_layer=None)
model_xai = method.prepare_model()
assert hasattr(method, "_detect_hook_handle")
assert has_xai(model_xai)
data = np.random.rand(1, 3, 5, 5)
output = method.model_forward(data)
assert not hasattr(method, "_detect_hook_handle")
assert type(output) == dict
assert method._feature_modules[-3] is model_xai.norm1
assert method._feature_module is model_xai.norm1
output = method.model_forward(data)
assert type(output) == dict # still good for 2nd forward

@@ -222,13 +214,14 @@ def test_reciprocam(optimize_gap: bool) -> None:
def test_vitreciprocam(use_gaussian: bool, use_cls_token: bool) -> None:
batch_size = 2
num_classes = 3
model = DummyVIT(num_classes=num_classes)
dim = 3
model = DummyVIT(num_classes=num_classes, dim=dim)
method = TorchViTReciproCAM(
model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token
)
model_xai = method.prepare_model()
assert has_xai(model_xai)
data = np.random.rand(batch_size, 4, 5, 5)
data = np.random.rand(batch_size, dim, 5, 5)
output = method.model_forward(data)
assert type(output) == dict
saliency_maps = output[SALIENCY_MAP_OUTPUT_NAME]
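Note: the new prepare_model flag makes XAI insertion an explicit step rather than a constructor side effect, as test_prepare_model above exercises. A minimal sketch, assuming the import path below; the tiny model is illustrative, while the class, flag, and method names come from this diff:

import numpy as np
import torch

# Assumed import path, based on the module changed in this commit.
from openvino_xai.methods.white_box.torch import TorchWhiteBoxMethod

class TinyCNN(torch.nn.Module):
    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.feature = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.head = torch.nn.LazyLinear(out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.feature(x)
        x = x.mean(dim=(2, 3))  # global average pooling
        return self.head(x)

# prepare_model=False defers hook insertion until prepare_model() is called.
method = TorchWhiteBoxMethod(model=TinyCNN(), target_layer="feature", prepare_model=False)
assert method._model_compiled is None  # nothing prepared yet
model_xai = method.prepare_model()  # insert XAI hooks and load for inference
output = method.model_forward(np.random.rand(1, 3, 16, 16))
assert isinstance(output, dict)  # prediction plus saliency map entries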
