Commit 707bb82 — FrancescoSaverioZuppichini, committed Dec 28, 2018 (1 parent: 96a74c4)

Showing 15 changed files with 476 additions and 0 deletions.
Seven binary files changed (including mirror/visualisations/__pycache__/Visualisation.cpython-36.pyc, modified by -10 bytes); binary contents are not shown.
New file — Base.py (59 lines added):
class Base:
    def __init__(self, module, device):
        self.module, self.device = module, device
        self.handles = []

    def clean(self):
        # remove every hook registered so far
        for h in self.handles:
            h.remove()

    def __call__(self, inputs, layer, *args, **kwargs):
        # base class is a no-op visualisation
        return inputs, {}


class LayerFeatures(Base):
    def __init__(self, layer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.grads, self.outputs = None, None
        self.layer = layer

    def store_grads(self):
        def hook(module, grad_in, grad_out):
            self.clean()
            self.grads = grad_in[0]

        self.handles.append(self.layer.register_backward_hook(hook))

    def store_outputs(self):
        def hook(module, inputs, outputs):
            self.clean()
            self.outputs = outputs

        # forward hook: the hook signature receives the layer's outputs
        # (the original line registered a backward hook here by mistake)
        self.handles.append(self.layer.register_forward_hook(hook))

    @property
    def has_grads(self):
        return self.grads is not None

    @property
    def has_outputs(self):
        return self.outputs is not None


class Visualisation(Base):

    def trace(self, module, inputs):
        # record every leaf module in the order it is executed
        self.modules = []

        def hook(module, inputs, outputs):
            self.modules.append(module)

        def traverse(module):
            for m in module.children():
                traverse(m)
            is_leaf = len(list(module.children())) == 0
            if is_leaf: self.handles.append(module.register_forward_hook(hook))

        traverse(module)

        # one forward pass fires the hooks and fills self.modules
        _ = module(inputs)

        self.clean()
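For context, a minimal usage sketch of the Visualisation tracer defined above. The resnet18 model, the input size and the CPU device are illustrative assumptions, not part of the commit:

import torch
from torchvision.models import resnet18

# Hypothetical example: trace all leaf modules of a network in execution order.
model = resnet18(pretrained=False).eval()
viz = Visualisation(model, device='cpu')
x = torch.randn(1, 3, 224, 224)
viz.trace(model, x)            # registers forward hooks and runs one forward pass
print(len(viz.modules))        # number of leaf modules visited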
New file — ClassActivationMapping.py (52 lines added):
import torch

from torch.nn import AvgPool2d, Conv2d, Linear, ReLU
from torch.nn.functional import softmax

from .Base import Base

from .utils import module2traced, imshow, tensor2cam

import torch.nn.functional as F


class ClassActivationMapping(Base):
    """
    Based on "Learning Deep Features for Discriminative Localization"
    (https://arxiv.org/abs/1512.04150).
    Be aware: it requires the feature maps to directly precede the softmax layer,
    so it works for architectures such as ResNet but not, for example, for AlexNet.
    """

    def __call__(self, inputs, layer, target_class=None, postprocessing=lambda x: x, guide=False):
        modules = module2traced(self.module, inputs)
        last_conv = None
        last_linear = None

        # walk the traced modules to find the last convolution and the last linear layer
        for i, module in enumerate(modules):
            if isinstance(module, Conv2d):
                last_conv = module
            if isinstance(module, AvgPool2d):
                pass
            if isinstance(module, Linear):
                last_linear = module

        def store_conv_outputs(module, inputs, outputs):
            self.conv_outputs = outputs

        # keep the handle so clean() can remove it afterwards
        self.handles.append(last_conv.register_forward_hook(store_conv_outputs))

        predictions = self.module(inputs)

        if target_class is None: _, target_class = torch.max(predictions, dim=1)
        _, c, h, w = self.conv_outputs.shape
        # get the weights of the last linear layer relative to the target class
        fc_weights_class = last_linear.weight.data[target_class]
        # weighted sum: each channel of the last convolution output is scaled by
        # its class weight w_k and the results are summed
        cam = fc_weights_class @ self.conv_outputs.view((c, h * w))
        cam = cam.view(h, w)

        with torch.no_grad():
            image_with_heatmap = tensor2cam(postprocessing(inputs.squeeze()), cam)

        self.clean()

        return image_with_heatmap.unsqueeze(0), { 'prediction': target_class }
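A hedged usage sketch for the class above. The resnet18 model and the random tensor stand in for a real preprocessed image, and the layer argument is unused because the last Conv2d is located automatically:

import torch
from torchvision.models import resnet18

# Hypothetical example: Class Activation Mapping on a ResNet.
model = resnet18(pretrained=True).eval()
cam = ClassActivationMapping(model, device='cpu')
x = torch.randn(1, 3, 224, 224)        # stand-in for a preprocessed image
heatmap, info = cam(x, layer=None)     # layer is ignored: the last Conv2d is found by tracing
print(heatmap.shape, info['prediction'])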
New file — DeepDream.py (89 lines added):
import torch
import torchvision.transforms.functional as TF

from torch.autograd import Variable
from PIL import Image, ImageFilter, ImageChops
from .Base import Base
from .utils import image_net_postprocessing, \
    image_net_preprocessing


class DeepDream(Base):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.handle = None

    def register_hooks(self):
        if self.handle: self.handle.remove()

        def hook(module, input, output):
            if module == self.layer:
                self.layer_output = output

                self.optimizer.zero_grad()
                # gradient ascent: maximise the norm of the target layer's activations
                loss = -torch.norm(self.layer_output)
                loss.backward()
                self.optimizer.step()

                # stop the forward pass here; the exception is caught in step()
                raise Exception('Layer found!')

        return self.layer.register_forward_hook(hook)

    def step(self, image, steps=5, save=False):

        self.module.zero_grad()
        image_pre = image_net_preprocessing(image.squeeze().cpu()).to(self.device).unsqueeze(0)
        self.image_var = Variable(image_pre, requires_grad=True).to(self.device)

        self.optimizer = torch.optim.Adam([self.image_var], lr=self.lr)

        for i in range(steps):
            try:
                self.module(self.image_var)
            except Exception:
                # raised by the hook once the target layer has been reached
                pass

        dreamed = self.image_var.data.squeeze()
        c, w, h = dreamed.shape

        dreamed = image_net_postprocessing(dreamed.cpu()).to(self.device)
        dreamed = torch.clamp(dreamed, 0.0, 1.0)

        del self.image_var, image_pre

        return dreamed

    def deep_dream(self, image, n, top, scale_factor):
        if n > 0:
            # note: tensors are (B, C, H, W), so w and h are effectively swapped here;
            # this is harmless for square inputs
            b, c, w, h = image.shape
            image = TF.to_pil_image(image.squeeze().cpu())
            # build a downscaled, slightly blurred octave and dream on it first
            image_down = TF.resize(image, (int(w * scale_factor), int(h * scale_factor)), Image.ANTIALIAS)
            image_down = image_down.filter(ImageFilter.GaussianBlur(0.5))

            image_down = TF.to_tensor(image_down).unsqueeze(0)
            from_down = self.deep_dream(image_down, n - 1, top, scale_factor)

            # upscale the dreamed octave and blend it with the original image
            from_down = TF.to_pil_image(from_down.squeeze().cpu())
            from_down = TF.resize(from_down, (w, h), Image.ANTIALIAS)

            image = ImageChops.blend(from_down, image, 0.6)

            image = TF.to_tensor(image).to(self.device)
            n = n - 1

        return self.step(image, steps=8, save=top == n + 1)

    def __call__(self, inputs, layer, octaves=6, scale_factor=0.7, lr=0.1):
        self.layer, self.lr = layer, lr
        self.handle = self.register_hooks()
        self.module.zero_grad()

        dd = self.deep_dream(inputs, octaves,
                             top=octaves,
                             scale_factor=scale_factor)
        self.handle.remove()

        return dd.unsqueeze(0), {}
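A minimal sketch of how DeepDream might be invoked, assuming a torchvision VGG16 and an arbitrarily chosen intermediate layer (both are assumptions for illustration, not part of the commit):

import torch
from torchvision.models import vgg16

# Hypothetical example: dream on an intermediate VGG16 layer.
model = vgg16(pretrained=True).eval()
dream = DeepDream(model, device='cpu')
layer = model.features[20]          # arbitrary intermediate layer, chosen for illustration
x = torch.rand(1, 3, 224, 224)      # stand-in for a real image in [0, 1]
dreamed, _ = dream(x, layer, octaves=4, scale_factor=0.7, lr=0.1)
print(dreamed.shape)                # dreamed image batch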
New file — GradCam.py (72 lines added):
import cv2
import numpy as np
import torch

from torch.autograd import Variable
from torch.nn import AvgPool2d, Conv2d, Linear, ReLU, MaxPool2d, BatchNorm2d
import torch.nn.functional as F

from .Base import Base
from .utils import tensor2cam, module2traced, imshow


class GradCam(Base):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.handles = []
        self.gradients = None
        self.conv_outputs = None

    def store_outputs_and_grad(self, layer):
        def store_grads(module, grad_in, grad_out):
            self.gradients = grad_out[0]

        def store_outputs(module, input, outputs):
            if module == layer:
                self.conv_outputs = outputs

        self.handles.append(layer.register_forward_hook(store_outputs))
        self.handles.append(layer.register_backward_hook(store_grads))

    def guide(self, module):
        # guided backpropagation: clamp negative gradients flowing through every ReLU
        def guide_relu(module, grad_in, grad_out):
            return (torch.clamp(grad_out[0], min=0.0),)

        for module in module.modules():
            if isinstance(module, ReLU):
                self.handles.append(module.register_backward_hook(guide_relu))

    def __call__(self, input_image, layer, guide=False, target_class=None, postprocessing=lambda x: x):
        self.clean()
        self.module.zero_grad()

        # if no layer is given, fall back to the last convolution of the traced network
        if layer is None:
            modules = module2traced(self.module, input_image)
            for i, module in enumerate(modules):
                if isinstance(module, Conv2d):
                    layer = module

        self.store_outputs_and_grad(layer)

        if guide: self.guide(self.module)

        input_var = Variable(input_image, requires_grad=True).to(self.device)
        predictions = self.module(input_var)

        if target_class is None: values, target_class = torch.max(predictions, dim=1)

        # one-hot target so that backward propagates only the chosen class score
        target = torch.zeros(predictions.size()).to(self.device)
        target[0][target_class] = 1

        predictions.backward(gradient=target, retain_graph=True)

        with torch.no_grad():
            # weight each feature map by its average gradient, then ReLU the weighted sum
            avg_channel_grad = F.adaptive_avg_pool2d(self.gradients.data, 1)
            cam = F.relu(torch.sum(self.conv_outputs[0] * avg_channel_grad[0], dim=0))
            image_with_heatmap = tensor2cam(postprocessing(input_image.squeeze()), cam)

        self.clean()

        return image_with_heatmap.unsqueeze(0), { 'prediction': target_class }
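A hedged usage sketch; the model, the random input and the device are illustrative assumptions:

import torch
from torchvision.models import resnet18

# Hypothetical example: Grad-CAM on a ResNet with guided backpropagation.
model = resnet18(pretrained=True).eval()
gradcam = GradCam(model, device='cpu')
x = torch.randn(1, 3, 224, 224)
# layer=None lets GradCam pick the last Conv2d it finds by tracing the model
heatmap, info = gradcam(x, layer=None, guide=True)
print(heatmap.shape, info['prediction'])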
New file (1 line added):

# cnn-visualisations
New file — SaliencyMap.py (72 lines added):
import torch

from .Base import Base
from torch.nn import ReLU
from torch.autograd import Variable
from torchvision.transforms import *
from .utils import convert_to_grayscale


class SaliencyMap(Base):
    """
    Simonyan, Vedaldi, and Zisserman, "Deep Inside Convolutional Networks: Visualising
    Image Classification Models and Saliency Maps", ICLR Workshop 2014.
    https://arxiv.org/abs/1312.6034
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.gradients = None
        self.handles = []
        self.stored_grad = False

    def store_first_layer_grad(self):

        def hook_grad_input(module, inputs, outputs):
            # register the backward hook only once, on the first module that runs,
            # i.e. the first layer of the network
            if not self.stored_grad:
                self.handles.append(module.register_backward_hook(store_grad))
                self.stored_grad = True

        def store_grad(module, grad_in, grad_out):
            self.gradients = grad_in[0]

        for module in self.module.modules():
            self.handles.append(module.register_forward_hook(hook_grad_input))

    def guide(self, module):
        # guided backpropagation: clamp negative gradients flowing through every ReLU
        def guide_relu(module, grad_in, grad_out):
            return (torch.clamp(grad_in[0], min=0.0),)

        for module in module.modules():
            if isinstance(module, ReLU):
                self.handles.append(module.register_backward_hook(guide_relu))

    def __call__(self, input_image, layer, guide=False, target_class=None):
        self.stored_grad = False

        self.clean()
        if guide: self.guide(self.module)

        input_image = Variable(input_image, requires_grad=True).to(self.device)

        self.store_first_layer_grad()

        predictions = self.module(input_image)

        if target_class is None: _, target_class = torch.max(predictions, dim=1)

        # one-hot target so that backward propagates only the chosen class score
        one_hot_output = torch.zeros(predictions.size()).to(self.device)
        one_hot_output[0][target_class] = 1

        self.module.zero_grad()

        predictions.backward(gradient=one_hot_output)

        image = self.gradients.data.cpu().numpy()[0]

        image = convert_to_grayscale(image)
        image = torch.from_numpy(image).to(self.device)

        self.clean()

        return image.unsqueeze(0), { 'prediction': target_class }
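A minimal usage sketch under illustrative assumptions (the alexnet model and the random input are not part of the commit; the layer argument is unused by this visualisation):

import torch
from torchvision.models import alexnet

# Hypothetical example: a guided saliency map for the predicted class.
model = alexnet(pretrained=True).eval()
saliency = SaliencyMap(model, device='cpu')
x = torch.randn(1, 3, 224, 224)
grad_image, info = saliency(x, layer=None, guide=True)
print(grad_image.shape, info['prediction'])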
New file — Weights.py (18 lines added):
from .Base import Base


class Weights(Base):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.outputs = None

    def hook(self, module, input, output):
        self.clean()
        self.outputs = output

    def __call__(self, inputs, layer, *args, **kwargs):
        self.handles.append(layer.register_forward_hook(self.hook))
        self.module(inputs)
        b, c, h, w = self.outputs.shape
        # reshape so each channel becomes a 1-channel image (assumes a batch size of 1)
        outputs = self.outputs.view(c, b, h, w)
        return outputs, {}
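A short, hedged example of how Weights could be called on a specific layer (the model and the layer choice are assumptions for illustration):

import torch
from torchvision.models import resnet18

# Hypothetical example: visualise one layer's activations as a stack of
# single-channel images.
model = resnet18(pretrained=False).eval()
weights = Weights(model, device='cpu')
x = torch.randn(1, 3, 224, 224)
maps, _ = weights(x, layer=model.layer1[0].conv1)
print(maps.shape)   # (C, 1, H, W): one 1-channel image per feature map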
New file — __init__.py (6 lines added):
from .SaliencyMap import SaliencyMap
from .DeepDream import DeepDream
from .GradCam import GradCam
from .Weights import Weights
from .Base import Base
from .ClassActivationMapping import ClassActivationMapping
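With these exports in place, the visualisations can be imported directly from the package. The package path mirror.visualisations is inferred from the compiled-cache path above and may differ:

# Hypothetical import, assuming the package lives at mirror/visualisations
from mirror.visualisations import GradCam, SaliencyMap, DeepDream, ClassActivationMapping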