Commit aaa9115 (1 parent: c094d2d)
Showing 3 changed files with 141 additions and 0 deletions.
Empty file.
build/lib/squeeze_and_excitation/squeeze_and_excitation.py (141 additions, 0 deletions)

""" | ||
Squeeze and Excitation Module | ||
***************************** | ||
Collection of squeeze and excitation classes where each can be inserted as a block into a neural network architechture | ||
1. `Channel Squeeze and Excitation <https://arxiv.org/abs/1709.01507>`_ | ||
2. `Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_ | ||
3. `Channel and Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_ | ||
""" | ||
|
||
from enum import Enum | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class ChannelSELayer(nn.Module):
    """
    Re-implementation of the Squeeze-and-Excitation (SE) block described in:
    *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
    """

    def __init__(self, num_channels, reduction_ratio=2):
        """
        :param num_channels: number of input channels
        :param reduction_ratio: factor by which num_channels is reduced in the bottleneck
        """
        super(ChannelSELayer, self).__init__()
        num_channels_reduced = num_channels // reduction_ratio
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :return: output tensor of the same shape
        """
        batch_size, num_channels, H, W = input_tensor.size()
        # Squeeze: global average pooling over each channel's spatial dimensions
        squeeze_tensor = input_tensor.view(batch_size, num_channels, -1).mean(dim=2)

        # Excitation: bottleneck MLP followed by a per-channel sigmoid gate
        fc_out_1 = self.relu(self.fc1(squeeze_tensor))
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        # Rescale each input channel by its gate
        a, b = squeeze_tensor.size()
        output_tensor = torch.mul(input_tensor, fc_out_2.view(a, b, 1, 1))
        return output_tensor


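# Usage sketch (editor's addition, not part of the original commit): a quick
# shape check for ChannelSELayer on a random tensor. The sizes below are
# illustrative assumptions, not values from the source.
def _demo_channel_se():
    x = torch.randn(2, 16, 8, 8)  # (batch, channels, H, W), toy values
    se = ChannelSELayer(num_channels=16, reduction_ratio=2)
    y = se(x)
    assert y.shape == x.shape  # gating rescales activations, never shapes
    return y

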
class SpatialSELayer(nn.Module):
    """
    Re-implementation of the SE block -- squeezing spatially and exciting channel-wise -- described in:
    *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
    """

    def __init__(self, num_channels):
        """
        :param num_channels: number of input channels
        """
        super(SpatialSELayer, self).__init__()
        self.conv = nn.Conv2d(num_channels, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor, weights=None):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :param weights: optional externally supplied 1x1 conv weights, e.g. for few-shot learning
        :return: output_tensor of the same shape
        """
        # Spatial squeeze: collapse the channel dimension to a single map
        batch_size, channel, a, b = input_tensor.size()

        # Note: truth-testing a multi-element tensor raises an error, so the
        # few-shot path must be selected with an explicit None check.
        if weights is not None:
            weights = weights.view(1, channel, 1, 1)
            out = F.conv2d(input_tensor, weights)
        else:
            out = self.conv(input_tensor)
        squeeze_tensor = self.sigmoid(out)

        # Spatial excitation: rescale every channel by the spatial gate
        output_tensor = torch.mul(input_tensor, squeeze_tensor.view(batch_size, 1, a, b))

        return output_tensor


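# Usage sketch (editor's addition): exercise both forward paths of
# SpatialSELayer -- the module's own 1x1 convolution and externally supplied
# weights. Sizes and the weight tensor are illustrative assumptions.
def _demo_spatial_se():
    x = torch.randn(2, 16, 8, 8)
    se = SpatialSELayer(num_channels=16)
    y = se(x)                      # learned 1x1 convolution path
    w = torch.randn(16)            # hypothetical external weights, one per channel
    y_few_shot = se(x, weights=w)  # few-shot path via F.conv2d
    assert y.shape == x.shape == y_few_shot.shape
    return y, y_few_shot

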
class ChannelSpatialSELayer(nn.Module):
    """
    Re-implementation of concurrent spatial and channel squeeze & excitation:
    *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, arXiv:1803.02579*
    """

    def __init__(self, num_channels, reduction_ratio=2):
        """
        :param num_channels: number of input channels
        :param reduction_ratio: factor by which num_channels is reduced inside the channel SE block
        """
        super(ChannelSpatialSELayer, self).__init__()
        self.cSE = ChannelSELayer(num_channels, reduction_ratio)
        self.sSE = SpatialSELayer(num_channels)

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, H, W)
        :return: output_tensor of the same shape
        """
        # Combine the two recalibrations with an element-wise maximum
        output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor))
        return output_tensor


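# Usage sketch (editor's addition): the concurrent block applies channel and
# spatial recalibration in parallel, so the output shape again matches the
# input. Sizes are illustrative assumptions.
def _demo_channel_spatial_se():
    x = torch.randn(2, 16, 8, 8)
    se = ChannelSpatialSELayer(num_channels=16, reduction_ratio=2)
    y = se(x)
    assert y.shape == x.shape
    return y

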
class SELayer(Enum):
    """
    Enum restricting the types of SE blocks available, so that type checking
    can be added when wiring these blocks into a neural network::

        if self.se_block_type == se.SELayer.CSE.value:
            self.SELayer = se.ChannelSELayer(params['num_filters'])
        elif self.se_block_type == se.SELayer.SSE.value:
            self.SELayer = se.SpatialSELayer(params['num_filters'])
        elif self.se_block_type == se.SELayer.CSSE.value:
            self.SELayer = se.ChannelSpatialSELayer(params['num_filters'])
    """
    NONE = 'NONE'
    CSE = 'CSE'
    SSE = 'SSE'
    CSSE = 'CSSE'


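# Factory sketch (editor's addition): one way to turn an SELayer value into a
# block, mirroring the dispatch in the enum's docstring. The function name is
# hypothetical; the params['num_filters'] key follows the docstring example.
def _create_se_block(se_block_type, params):
    if se_block_type == SELayer.CSE.value:
        return ChannelSELayer(params['num_filters'])
    elif se_block_type == SELayer.SSE.value:
        return SpatialSELayer(params['num_filters'])
    elif se_block_type == SELayer.CSSE.value:
        return ChannelSpatialSELayer(params['num_filters'])
    return None  # SELayer.NONE: no squeeze-and-excitation block
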
Binary file not shown.