Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto-detect feature layer for Pytorch models #64

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 75 additions & 11 deletions openvino_xai/methods/white_box/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,37 @@
target_layer: str | None = None,
embed_scaling: bool = True,
device_name: str = "CPU",
prepare_model: bool = True,
**kwargs,
):
super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name)
self._target_layer = target_layer
self._embed_scaling = embed_scaling

if prepare_model:
self.prepare_model()

def prepare_model(self, load_model: bool = True) -> torch.nn.Module:
"""Return XAI inserted model."""
if has_xai(self._model):
if load_model:
self._model_compiled = self._model
return self._model
if self._model_compiled is not None:
return self._model_compiled

model = copy.deepcopy(self._model)

# Feature
feature_layer = model.get_submodule(self._target_layer)
feature_layer.register_forward_hook(self._feature_hook)
if self._target_layer:
feature_module = self._find_feature_module_by_name(model, self._target_layer)
else:
feature_module = self._find_feature_module_auto(model)
feature_module.register_forward_hook(self._feature_hook)

# Output
model.register_forward_hook(self._output_hook)

setattr(model, "has_xai", True)
model.eval()

Expand All @@ -86,11 +98,51 @@
output[name] = data.numpy(force=True)
return output

def _find_feature_module_by_name(self, model: torch.nn.Module, target_name: str) -> torch.nn.Module:
    """Return the last sub-module whose name contains ``target_name``.

    Matching is a plain substring test against the names yielded by
    ``model.named_modules()``; when several names match, the one yielded
    last is returned.

    Args:
        model: Model whose sub-modules are searched.
        target_name: Substring to look for in the module names.

    Returns:
        The last matching sub-module.

    Raises:
        ValueError: If no module name contains ``target_name``.
    """
    matches = [candidate for name, candidate in model.named_modules() if target_name in name]
    if not matches:
        raise ValueError(f"{target_name} is not found in the torch model")
    return matches[-1]

def _find_feature_module_auto(self, module: torch.nn.Module) -> torch.nn.Module:
    """Detect the feature module by probing the model with a dummy forward pass.

    A global forward hook numbers every module that returns a tensor, in call
    order, and remembers the last one whose output is 4D with spatial size
    larger than 1x1. That module is taken as the feature layer.

    The probe runs in eval mode under ``torch.no_grad()`` so that it does not
    build an autograd graph or disturb layers that behave differently in train
    mode (e.g. BatchNorm running statistics, Dropout). The original
    train/eval flag of every sub-module is restored afterwards.

    Args:
        module: Model to inspect. It must accept a ``1x3x128x128`` input.

    Returns:
        The detected feature module (also stored in ``self._feature_module``).

    Raises:
        RuntimeError: If no module produces a 4D output with H > 1 and W > 1, or
            if the last such module is called in the first half of the forward
            pass (typical for ViT-like architectures).
    """
    self._feature_module = None
    self._num_modules = 0
    # Call order is tracked locally instead of writing an `index` attribute onto
    # every sub-module of the model that is later returned to the user.
    call_order: dict[torch.nn.Module, int] = {}

    def _detect_hook(hooked: torch.nn.Module, inputs: Any, output: Any) -> None:
        if not isinstance(output, torch.Tensor):
            return
        call_order[hooked] = self._num_modules
        self._num_modules += 1
        shape = output.shape
        if len(shape) == 4 and shape[2] > 1 and shape[3] > 1:
            self._feature_module = hooked

    # Build the dummy input on the model's own device / floating dtype so that
    # the probe also works for models that do not live on CPU in float32.
    tensor_kwargs: dict[str, Any] = {}
    reference = next(module.parameters(), None)
    if reference is not None:
        tensor_kwargs["device"] = reference.device
        if reference.dtype.is_floating_point:
            tensor_kwargs["dtype"] = reference.dtype
    dummy_input = torch.zeros((1, 3, 128, 128), **tensor_kwargs)

    # Remember per-module flags: a model may deliberately mix train/eval sub-modules.
    training_flags = [(submodule, submodule.training) for submodule in module.modules()]
    global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
    try:
        module.eval()
        with torch.no_grad():
            # `forward` is called directly so the root module itself is not hooked/counted.
            module.forward(dummy_input)
    finally:
        global_hook_handle.remove()
        for submodule, was_training in training_flags:
            submodule.training = was_training

    if self._feature_module is None:
        raise RuntimeError("Feature module with 4D output is not found in the torch model")

    feature_index = call_order[self._feature_module]
    if feature_index / self._num_modules < 0.5:  # Check if ViT-like architectures
        raise RuntimeError(
            f"Modules with 4D output end in early-half stages: {100 * feature_index / self._num_modules}%"
        )

    return self._feature_module

def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
    """Forward hook for the feature module: record its output as the feature map.

    The tensor is stored on ``self._feature_map`` and handed back unchanged.

    Args:
        module: Module the hook is attached to (unused).
        inputs: Inputs the module was called with (unused).
        output: Output produced by the module.

    Returns:
        The same ``output`` object that was passed in.
    """
    feature_map = output
    self._feature_map = feature_map
    return feature_map

def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Split combined output B0xC into BxC prediction and BxCxHxW saliency map."""
return {
"prediction": output,
SALIENCY_MAP_OUTPUT_NAME: torch.empty_like(output),
Expand Down Expand Up @@ -137,8 +189,8 @@
"""

def __init__(self, *args, optimize_gap: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self._optimize_gap = optimize_gap
super().__init__(*args, **kwargs)

def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
"""feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
Expand All @@ -153,16 +205,17 @@
return torch.cat(feature_maps)

def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
batch_size, _, h, w = self._feature_shape
num_classes = output.shape[1]
predictions = output[:batch_size]
saliency_maps = output[batch_size:]
saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes])
saliency_maps = saliency_maps.transpose(1, 2) # BxHWxC -> BxCxHW
"""Split combined output B0xC into BxC prediction and BxCxHxW saliency map."""
batch_size, _, h, w = self._feature_shape # B0xDxHxW
goodsong81 marked this conversation as resolved.
Show resolved Hide resolved
num_classes = output.shape[1] # C
predictions = output[:batch_size] # BxC
saliency_maps = output[batch_size:] # BHWxC
saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes]) # BxHWxC
saliency_maps = saliency_maps.transpose(1, 2) # BxCxHW
if self._embed_scaling:
saliency_maps = saliency_maps.reshape((batch_size * num_classes, h * w))
saliency_maps = self._normalize_map(saliency_maps)
saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w])
saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w]) # BxCxHxW
return {
"prediction": predictions,
SALIENCY_MAP_OUTPUT_NAME: saliency_maps,
Expand Down Expand Up @@ -209,9 +262,20 @@
normalize: bool = True,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self._use_gaussian = use_gaussian
self._use_cls_token = use_cls_token
super().__init__(*args, **kwargs)

def _find_feature_module_auto(self, module: torch.nn.Module) -> torch.nn.Module:
    """Pick the third-from-last ``LayerNorm`` of the model as the feature module.

    ``self._feature_module`` is reset to ``None`` first, so it stays ``None``
    when detection fails.

    Args:
        module: Model whose sub-modules are scanned for ``torch.nn.LayerNorm``.

    Returns:
        The selected ``LayerNorm`` module (also stored in ``self._feature_module``).

    Raises:
        RuntimeError: If the model contains fewer than three ``LayerNorm`` modules.
    """
    self._feature_module = None

    layer_norms = []
    for _, submodule in module.named_modules():
        if isinstance(submodule, torch.nn.LayerNorm):
            layer_norms.append(submodule)

    if not len(layer_norms) >= 3:
        raise RuntimeError("Feature modules with LayerNorm are less than 3 in the torch model")

    *_, selected, _, _ = layer_norms
    self._feature_module = selected
    return selected

def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
"""feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
Expand Down
10 changes: 7 additions & 3 deletions tests/intg/test_classification_timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,8 @@ def test_model_format(self, model_id, explain_mode, model_format):
"deit_tiny_patch16_224.fb_in1k",
],
)
def test_torch_insert_xai_with_layer(self, model_id: str):
@pytest.mark.parametrize("detect", ["auto", "name"])
def test_torch_insert_xai_with_layer(self, model_id: str, detect: str):
xai_cfg = {
"resnet18.a1_in1k": ("layer4", Method.RECIPROCAM),
"efficientnet_b0.ra_in1k": ("bn2", Method.RECIPROCAM),
Expand All @@ -465,6 +466,9 @@ def test_torch_insert_xai_with_layer(self, model_id: str):
model_dir = self.data_dir / "timm_models" / "converted_models"
model, model_cfg = self.get_timm_model(model_id, model_dir)

target_layer = xai_cfg[model_id][0] if detect == "name" else None
explain_method = xai_cfg[model_id][1]

image = cv2.imread("tests/assets/cheetah_person.jpg")
image = cv2.resize(image, dsize=model_cfg["input_size"][1:])
image = cv2.cvtColor(image, code=cv2.COLOR_BGR2RGB)
Expand All @@ -478,8 +482,8 @@ def test_torch_insert_xai_with_layer(self, model_id: str):
xai_model: torch.nn.Module = insert_xai(
model,
task=Task.CLASSIFICATION,
target_layer=xai_cfg[model_id][0],
explain_method=xai_cfg[model_id][1],
target_layer=target_layer,
explain_method=explain_method,
)

with torch.no_grad():
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/methods/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_create_wb_det_cnn_method(fxt_data_root: Path):
assert str(exc_info.value) == "Requested explanation method abc is not implemented."


def test_create_torch_method():
def test_create_torch_method(mocker: MockerFixture):
model = {}
with pytest.raises(ValueError):
explain_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn())
Expand All @@ -172,6 +172,10 @@ def test_create_torch_method():
Task.DETECTION, model, get_postprocess_fn(), target_layer=""
)

mocker.patch.object(torch_method.TorchActivationMap, "prepare_model")
mocker.patch.object(torch_method.TorchReciproCAM, "prepare_model")
mocker.patch.object(torch_method.TorchViTReciproCAM, "prepare_model")

model = torch.nn.Module()
explain_method = WhiteBoxMethodFactory.create_method(
Task.CLASSIFICATION, model, get_postprocess_fn(), explain_method=Method.ACTIVATIONMAP
Expand Down
71 changes: 65 additions & 6 deletions tests/unit/methods/white_box/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ class DummyCNN(torch.nn.Module):
def __init__(self, num_classes: int = 2):
super().__init__()
self.num_classes = num_classes
self.feature = torch.nn.Identity()
self.feature = torch.nn.Sequential(
torch.nn.Identity(),
torch.nn.Identity(),
torch.nn.Identity(),
torch.nn.Identity(),
)
self.neck = torch.nn.AdaptiveAvgPool2d((1, 1))
self.output = torch.nn.LazyLinear(out_features=num_classes)

Expand All @@ -48,24 +53,46 @@ def forward(self, x: torch.Tensor):


class DummyVIT(torch.nn.Module):
def __init__(self, num_classes: int = 2):
def __init__(self, num_classes: int = 2, dim: int = 3):
super().__init__()
self.num_classes = num_classes
self.feature = torch.nn.Identity()
self.dim = dim
self.pre = torch.nn.Sequential(
torch.nn.Identity(),
torch.nn.Identity(),
)
self.feature = torch.nn.Sequential(
torch.nn.Identity(),
torch.nn.Identity(),
torch.nn.Identity(),
torch.nn.Identity(),
)
self.norm1 = torch.nn.LayerNorm(dim)
self.norm2 = torch.nn.LayerNorm(dim)
self.norm3 = torch.nn.LayerNorm(dim)
self.output = torch.nn.LazyLinear(out_features=num_classes)

def forward(self, x: torch.Tensor):
b, c, h, w = x.shape
x = self.pre(x)
x = x.reshape(b, c, h * w)
x = x.transpose(1, 2)
x = torch.cat([torch.rand((b, 1, c)), x], dim=1)
x = self.feature(x)
x = x + self.norm1(x)
x = x + self.norm2(x)
x = x + self.norm3(x)
x = self.output(x[:, 0])
return torch.nn.functional.softmax(x, dim=1)


def test_torch_method():
model = DummyCNN()

with pytest.raises(ValueError):
method = TorchWhiteBoxMethod(model=model, target_layer="something_else")
model_xai = method.prepare_model()

method = TorchWhiteBoxMethod(model=model, target_layer="feature")
model_xai = method.prepare_model()
assert has_xai(model_xai)
Expand Down Expand Up @@ -101,7 +128,9 @@ def _output_hook(

def test_prepare_model():
model = DummyCNN()
method = TorchWhiteBoxMethod(model=model, target_layer="feature")
method = TorchWhiteBoxMethod(model=model, target_layer="feature", prepare_model=False)
model_xai = method.prepare_model(load_model=False)
assert method._model_compiled is None
model_xai = method.prepare_model(load_model=False)
assert method._model_compiled is None
assert model is not model_xai
Expand All @@ -116,6 +145,35 @@ def test_prepare_model():
assert model_xai == model


def test_detect_feature_layer():
    """Check feature-layer auto-detection (``target_layer=None``) on the dummy models.

    Covers three cases: the 4D-output search on a CNN, the same search being
    rejected for a ViT, and the LayerNorm-based search of ``TorchViTReciproCAM``.
    """
    # CNN: the generic 4D-output search should pick the `feature` block.
    model = DummyCNN()
    method = TorchWhiteBoxMethod(model=model, target_layer=None)
    model_xai = method.prepare_model()
    assert has_xai(model_xai)
    data = np.random.rand(1, 3, 5, 5)
    output = method.model_forward(data)
    assert isinstance(output, dict)
    assert method._feature_module is model_xai.feature
    output = method.model_forward(data)
    assert isinstance(output, dict)  # still good for 2nd forward

    model = DummyVIT()
    with pytest.raises(RuntimeError):
        # 4D feature map search should fail for ViTs
        TorchWhiteBoxMethod(model=model, target_layer=None)

    # ViT: the ViT-specific method should pick the 3rd last LayerNorm (`norm1`).
    model = DummyVIT()
    method = TorchViTReciproCAM(model=model, target_layer=None)
    model_xai = method.prepare_model()
    assert has_xai(model_xai)
    data = np.random.rand(1, 3, 5, 5)
    output = method.model_forward(data)
    assert isinstance(output, dict)
    assert method._feature_module is model_xai.norm1
    output = method.model_forward(data)
    assert isinstance(output, dict)  # still good for 2nd forward


def test_activationmap() -> None:
batch_size = 2
num_classes = 3
Expand Down Expand Up @@ -156,13 +214,14 @@ def test_reciprocam(optimize_gap: bool) -> None:
def test_vitreciprocam(use_gaussian: bool, use_cls_token: bool) -> None:
batch_size = 2
num_classes = 3
model = DummyVIT(num_classes=num_classes)
dim = 3
model = DummyVIT(num_classes=num_classes, dim=dim)
method = TorchViTReciproCAM(
model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token
)
model_xai = method.prepare_model()
assert has_xai(model_xai)
data = np.random.rand(batch_size, 4, 5, 5)
data = np.random.rand(batch_size, dim, 5, 5)
output = method.model_forward(data)
assert type(output) == dict
saliency_maps = output[SALIENCY_MAP_OUTPUT_NAME]
Expand Down