import torch
import torch.nn as nn
import torch.nn.functional as F
class GatedConv2dWithActivation(torch.nn.Module):
    """Gated convolution layer with activation (default activation: LeakyReLU).

    Computes ``phi(f(I)) * sigmoid(g(I))``, where ``f`` is a feature
    convolution, ``g`` is a mask (gating) convolution built with the same
    hyper-parameters, and ``phi`` is ``activation``. The product is then
    optionally passed through ``BatchNorm2d``.

    Params: ``in_channels`` .. ``bias`` are forwarded unchanged to both
    ``torch.nn.Conv2d`` layers (same meaning as for ``Conv2d``).

    Args:
        batch_norm: if true, apply ``BatchNorm2d(out_channels)`` to the gated
            output.
        activation: module applied to the feature branch, or ``None`` for no
            activation. NOTE: the default is a single
            ``LeakyReLU(0.2, inplace=True)`` instance, evaluated once when the
            class is defined and shared by every layer that relies on the
            default.

    Input: the feature map from the previous layer, ``I``.
    Output: ``phi(f(I)) * sigmoid(g(I))``, batch-normalized if ``batch_norm``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True, batch_norm=True,
                 activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(GatedConv2dWithActivation, self).__init__()
        self.batch_norm = batch_norm
        self.activation = activation
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, groups, bias)
        self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                           stride, padding, dilation, groups, bias)
        # Constructed unconditionally (even when batch_norm is False), so its
        # parameters and buffers are always registered on the module.
        self.batch_norm2d = torch.nn.BatchNorm2d(out_channels)
        self.sigmoid = torch.nn.Sigmoid()
        # Re-initialise every Conv2d weight in this module tree with Kaiming
        # normal init (biases keep PyTorch's default init).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

    def gated(self, mask):
        """Squash the raw mask-branch output with a sigmoid to get the gate."""
        return self.sigmoid(mask)

    def forward(self, input):
        """Return the gated (and optionally batch-normalized) feature map."""
        x = self.conv2d(input)
        mask = self.mask_conv2d(input)
        if self.activation is not None:
            # With the default in-place LeakyReLU, x is overwritten here; that
            # is safe because conv2d's backward does not need its own output.
            x = self.activation(x) * self.gated(mask)
        else:
            x = x * self.gated(mask)
        if self.batch_norm:
            return self.batch_norm2d(x)
        return x
class GatedDeConv2dWithActivation(torch.nn.Module):
    """Gated "deconvolution" layer with activation (default activation: LeakyReLU).

    Implemented as resize + conv: the input is upsampled with
    ``F.interpolate`` (default mode, i.e. nearest) by ``scale_factor`` and then
    fed through a ``GatedConv2dWithActivation``.

    Args:
        scale_factor: spatial upsampling factor passed to ``F.interpolate``.
        in_channels .. activation: forwarded unchanged to
            ``GatedConv2dWithActivation`` (same meaning as for ``Conv2d``).

    Input: the feature map from the previous layer, ``I``.
    Output: ``phi(f(up(I))) * sigmoid(g(up(I)))``, where ``up`` is the
    interpolation step, batch-normalized if ``batch_norm``.
    """

    def __init__(self, scale_factor, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, batch_norm=True,
                 activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(GatedDeConv2dWithActivation, self).__init__()
        self.conv2d = GatedConv2dWithActivation(in_channels, out_channels, kernel_size,
                                                stride, padding, dilation, groups, bias,
                                                batch_norm, activation)
        self.scale_factor = scale_factor

    def forward(self, input):
        """Upsample by ``self.scale_factor``, then apply the gated convolution."""
        # Use the configured factor rather than a hard-coded 2, so the
        # constructor's scale_factor argument actually takes effect.
        x = F.interpolate(input, scale_factor=self.scale_factor)
        return self.conv2d(x)
class SNGatedConv2dWithActivation(torch.nn.Module):
    """Gated convolution with spectral normalization.

    Same computation as a gated convolution, ``phi(f(I)) * sigmoid(g(I))``
    optionally followed by ``BatchNorm2d``. The difference is that both the
    feature conv ``f`` and the mask conv ``g`` are wrapped with
    ``torch.nn.utils.spectral_norm``.

    Params: ``in_channels`` .. ``bias`` are forwarded unchanged to both
    ``torch.nn.Conv2d`` layers (same meaning as for ``Conv2d``).

    Args:
        batch_norm: if true, apply ``BatchNorm2d(out_channels)`` to the gated
            output.
        activation: module applied to the feature branch, or ``None`` for no
            activation. NOTE: the default ``LeakyReLU(0.2, inplace=True)``
            instance is created once and shared by every layer using the
            default.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True, batch_norm=True,
                 activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(SNGatedConv2dWithActivation, self).__init__()
        conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                 stride, padding, dilation, groups, bias)
        mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, groups, bias)
        # Initialise the raw weights *before* wrapping. spectral_norm moves the
        # trainable tensor to ``weight_orig`` and makes ``weight`` a derived
        # tensor, so initialising ``weight`` afterwards only reaches the real
        # parameter through storage aliasing.
        for conv in (conv2d, mask_conv2d):
            nn.init.kaiming_normal_(conv.weight)
        # Registration order (conv2d, mask_conv2d, activation, batch_norm2d,
        # sigmoid) matches the original class.
        self.conv2d = torch.nn.utils.spectral_norm(conv2d)
        self.mask_conv2d = torch.nn.utils.spectral_norm(mask_conv2d)
        self.activation = activation
        self.batch_norm = batch_norm
        # Constructed unconditionally (even when batch_norm is False), so its
        # parameters and buffers are always registered on the module.
        self.batch_norm2d = torch.nn.BatchNorm2d(out_channels)
        self.sigmoid = torch.nn.Sigmoid()

    def gated(self, mask):
        """Squash the raw mask-branch output with a sigmoid to get the gate."""
        return self.sigmoid(mask)

    def forward(self, input):
        """Return the gated (and optionally batch-normalized) feature map."""
        x = self.conv2d(input)
        mask = self.mask_conv2d(input)
        if self.activation is not None:
            x = self.activation(x) * self.gated(mask)
        else:
            x = x * self.gated(mask)
        if self.batch_norm:
            return self.batch_norm2d(x)
        return x
class SNGatedDeConv2dWithActivation(torch.nn.Module):
    """Spectral-normalized gated "deconvolution" layer (default activation: LeakyReLU).

    Implemented as resize + conv: the input is upsampled with
    ``F.interpolate`` (default mode, i.e. nearest) by ``scale_factor`` and then
    fed through a ``SNGatedConv2dWithActivation``.

    Args:
        scale_factor: spatial upsampling factor passed to ``F.interpolate``.
        in_channels .. activation: forwarded unchanged to
            ``SNGatedConv2dWithActivation`` (same meaning as for ``Conv2d``).

    Input: the feature map from the previous layer, ``I``.
    Output: ``phi(f(up(I))) * sigmoid(g(up(I)))``, where ``up`` is the
    interpolation step, batch-normalized if ``batch_norm``.
    """

    def __init__(self, scale_factor, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, batch_norm=True,
                 activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(SNGatedDeConv2dWithActivation, self).__init__()
        self.conv2d = SNGatedConv2dWithActivation(in_channels, out_channels, kernel_size,
                                                  stride, padding, dilation, groups, bias,
                                                  batch_norm, activation)
        self.scale_factor = scale_factor

    def forward(self, input):
        """Upsample by ``self.scale_factor``, then apply the gated convolution."""
        # Use the configured factor rather than a hard-coded 2, so the
        # constructor's scale_factor argument actually takes effect.
        x = F.interpolate(input, scale_factor=self.scale_factor)
        return self.conv2d(x)