diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..724893e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,12 @@
+models
+models.zip
+models_pytorch.zip
+
+# python artifacts
+__pycache__
+*.pyc
+
+# Temporary Files
+*~
+*.swp
+*.swo
diff --git a/Readme.md b/Readme.md
index af978c2..bed764d 100644
--- a/Readme.md
+++ b/Readme.md
@@ -7,7 +7,7 @@ Official Torch implementation can be found [here](https://github.com/Yijunmaveri
 ## Prerequisites
 - [Pytorch](http://pytorch.org/)
 - [torchvision](https://github.com/pytorch/vision)
-- Pretrained encoder and decoder [models](https://drive.google.com/file/d/1M5KBPfqrIUZqrBZf78CIxLrMUT4lD4t9/view?usp=sharing) for image reconstruction only (download and uncompress them under models/)
+- Pretrained encoder and decoder [models](http://pascal.inrialpes.fr/data2/archetypal_style/models_pytorch.zip) for image reconstruction only (download and uncompress them under models/)
 - CUDA + CuDNN
 
 ## Prepare images
diff --git a/WCT.py b/WCT.py
index 9c5c141..9d7c5fc 100644
--- a/WCT.py
+++ b/WCT.py
@@ -8,23 +8,18 @@ from Loader import Dataset
 from util import *
 import scipy.misc
-from torch.utils.serialization import load_lua
 import time
 
 parser = argparse.ArgumentParser(description='WCT Pytorch')
 parser.add_argument('--contentPath',default='images/content',help='path to train')
 parser.add_argument('--stylePath',default='images/style',help='path to train')
 parser.add_argument('--workers', default=2, type=int, metavar='N',help='number of data loading workers (default: 4)')
-parser.add_argument('--vgg1', default='models/vgg_normalised_conv1_1.t7', help='Path to the VGG conv1_1')
-parser.add_argument('--vgg2', default='models/vgg_normalised_conv2_1.t7', help='Path to the VGG conv2_1')
-parser.add_argument('--vgg3', default='models/vgg_normalised_conv3_1.t7', help='Path to the VGG conv3_1')
-parser.add_argument('--vgg4', default='models/vgg_normalised_conv4_1.t7', help='Path to the VGG conv4_1')
-parser.add_argument('--vgg5', default='models/vgg_normalised_conv5_1.t7', help='Path to the VGG conv5_1')
-parser.add_argument('--decoder5', default='models/feature_invertor_conv5_1.t7', help='Path to the decoder5')
-parser.add_argument('--decoder4', default='models/feature_invertor_conv4_1.t7', help='Path to the decoder4')
-parser.add_argument('--decoder3', default='models/feature_invertor_conv3_1.t7', help='Path to the decoder3')
-parser.add_argument('--decoder2', default='models/feature_invertor_conv2_1.t7', help='Path to the decoder2')
-parser.add_argument('--decoder1', default='models/feature_invertor_conv1_1.t7', help='Path to the decoder1')
+parser.add_argument('--encoder', default='models/vgg19_normalized.pth.tar', help='Path to the VGG encoder')
+parser.add_argument('--decoder5', default='models/vgg19_normalized_decoder5.pth.tar', help='Path to the decoder5')
+parser.add_argument('--decoder4', default='models/vgg19_normalized_decoder4.pth.tar', help='Path to the decoder4')
+parser.add_argument('--decoder3', default='models/vgg19_normalized_decoder3.pth.tar', help='Path to the decoder3')
+parser.add_argument('--decoder2', default='models/vgg19_normalized_decoder2.pth.tar', help='Path to the decoder2')
+parser.add_argument('--decoder1', default='models/vgg19_normalized_decoder1.pth.tar', help='Path to the decoder1')
 parser.add_argument('--cuda', action='store_true', help='enables cuda')
 parser.add_argument('--batch_size', type=int, default=1, help='batch size')
 parser.add_argument('--fineSize', type=int, default=512, help='resize image to fineSize x fineSize,leave it to 0 if not resize')
@@ -45,70 +40,75 @@
 batch_size=1,
 shuffle=False)
 
-wct = WCT(args)
-def styleTransfer(contentImg,styleImg,imname,csF):
+def styleTransfer(wct, contentImg, styleImg, imname):
 
-    sF5 = wct.e5(styleImg)
-    cF5 = wct.e5(contentImg)
-    sF5 = sF5.data.cpu().squeeze(0)
-    cF5 = cF5.data.cpu().squeeze(0)
-    csF5 = wct.transform(cF5,sF5,csF,args.alpha)
+    sF5 = wct.encoder(styleImg, 'relu5_1')
+    cF5 = wct.encoder(contentImg, 'relu5_1')
+    sF5 = sF5.cpu().squeeze(0)
+    cF5 = cF5.cpu().squeeze(0)
+    csF5 = wct.transform(cF5,sF5,args.alpha)
+    csF5 = csF5.to(next(wct.parameters()).device)
     Im5 = wct.d5(csF5)
-    sF4 = wct.e4(styleImg)
-    cF4 = wct.e4(Im5)
-    sF4 = sF4.data.cpu().squeeze(0)
-    cF4 = cF4.data.cpu().squeeze(0)
-    csF4 = wct.transform(cF4,sF4,csF,args.alpha)
+    sF4 = wct.encoder(styleImg, 'relu4_1')
+    cF4 = wct.encoder(Im5, 'relu4_1')
+    sF4 = sF4.cpu().squeeze(0)
+    cF4 = cF4.cpu().squeeze(0)
+    csF4 = wct.transform(cF4,sF4,args.alpha)
+    csF4 = csF4.to(next(wct.parameters()).device)
     Im4 = wct.d4(csF4)
-    sF3 = wct.e3(styleImg)
-    cF3 = wct.e3(Im4)
-    sF3 = sF3.data.cpu().squeeze(0)
-    cF3 = cF3.data.cpu().squeeze(0)
-    csF3 = wct.transform(cF3,sF3,csF,args.alpha)
+    sF3 = wct.encoder(styleImg, 'relu3_1')
+    cF3 = wct.encoder(Im4, 'relu3_1')
+    sF3 = sF3.cpu().squeeze(0)
+    cF3 = cF3.cpu().squeeze(0)
+    csF3 = wct.transform(cF3,sF3,args.alpha)
+    csF3 = csF3.to(next(wct.parameters()).device)
     Im3 = wct.d3(csF3)
-    sF2 = wct.e2(styleImg)
-    cF2 = wct.e2(Im3)
-    sF2 = sF2.data.cpu().squeeze(0)
-    cF2 = cF2.data.cpu().squeeze(0)
-    csF2 = wct.transform(cF2,sF2,csF,args.alpha)
+    sF2 = wct.encoder(styleImg, 'relu2_1')
+    cF2 = wct.encoder(Im3, 'relu2_1')
+    sF2 = sF2.cpu().squeeze(0)
+    cF2 = cF2.cpu().squeeze(0)
+    csF2 = wct.transform(cF2,sF2,args.alpha)
+    csF2 = csF2.to(next(wct.parameters()).device)
     Im2 = wct.d2(csF2)
-    sF1 = wct.e1(styleImg)
-    cF1 = wct.e1(Im2)
-    sF1 = sF1.data.cpu().squeeze(0)
-    cF1 = cF1.data.cpu().squeeze(0)
-    csF1 = wct.transform(cF1,sF1,csF,args.alpha)
+    sF1 = wct.encoder(styleImg, 'relu1_1')
+    cF1 = wct.encoder(Im2, 'relu1_1')
+    sF1 = sF1.cpu().squeeze(0)
+    cF1 = cF1.cpu().squeeze(0)
+    csF1 = wct.transform(cF1,sF1,args.alpha)
+    csF1 = csF1.to(next(wct.parameters()).device)
     Im1 = wct.d1(csF1)
     # save_image has this wired design to pad images with 4 pixels at default.
-    vutils.save_image(Im1.data.cpu().float(),os.path.join(args.outf,imname))
+    vutils.save_image(Im1.cpu().float(),os.path.join(args.outf,imname))
     return
 
-avgTime = 0
-cImg = torch.Tensor()
-sImg = torch.Tensor()
-csF = torch.Tensor()
-csF = Variable(csF)
-if(args.cuda):
-    cImg = cImg.cuda(args.gpu)
-    sImg = sImg.cuda(args.gpu)
-    csF = csF.cuda(args.gpu)
-    wct.cuda(args.gpu)
-for i,(contentImg,styleImg,imname) in enumerate(loader):
-    imname = imname[0]
-    print('Transferring ' + imname)
-    if (args.cuda):
-        contentImg = contentImg.cuda(args.gpu)
-        styleImg = styleImg.cuda(args.gpu)
-    cImg = Variable(contentImg,volatile=True)
-    sImg = Variable(styleImg,volatile=True)
-    start_time = time.time()
-    # WCT Style Transfer
-    styleTransfer(cImg,sImg,imname,csF)
-    end_time = time.time()
-    print('Elapsed time is: %f' % (end_time - start_time))
-    avgTime += (end_time - start_time)
+def main():
+    wct = WCT(args)
+    if(args.cuda):
+        wct.cuda(args.gpu)
 
-print('Processed %d images. Averaged time is %f' % ((i+1),avgTime/(i+1)))
+    avgTime = 0
+    for i,(contentImg,styleImg,imname) in enumerate(loader):
+        imname = imname[0]
+        print('Transferring ' + imname)
+        if (args.cuda):
+            contentImg = contentImg.cuda(args.gpu)
+            styleImg = styleImg.cuda(args.gpu)
+        start_time = time.time()
+        # WCT Style Transfer
+        styleTransfer(wct, contentImg, styleImg, imname)
+        end_time = time.time()
+        print('Elapsed time is: %f' % (end_time - start_time))
+        avgTime += (end_time - start_time)
+
+    print('Processed %d images. Averaged time is %f' % ((i+1),avgTime/(i+1)))
+
+if __name__ == '__main__':
+    with torch.no_grad():
+        main()
diff --git a/convertModels.py b/convertModels.py
deleted file mode 100644
index 6dff8d6..0000000
--- a/convertModels.py
+++ /dev/null
@@ -1,133 +0,0 @@
-import numpy as np
-from imageio import imread
-from scipy.stats import describe
-import torch
-import torch.nn as nn
-from torch.utils.serialization import load_lua
-import torchvision.transforms as tvt
-
-import modelsNIPS
-import vgg19_normalized
-import vgg19_decoders
-
-CHECKPOINT_ENCODER_PY = 'models/vgg19_normalized.pth.tar'
-LUA_CHECKPOINT_VGG = 'models/vgg_normalised_conv{}_1.t7'
-
-TEST_IMAGE = 'images/content/in4.jpg'
-image_np = imread(TEST_IMAGE).astype(np.float32)
-
-ENCODERS = modelsNIPS.encoder1, modelsNIPS.encoder2, modelsNIPS.encoder3, modelsNIPS.encoder4, modelsNIPS.encoder5
-DECODERS = modelsNIPS.decoder1, modelsNIPS.decoder2, modelsNIPS.decoder3, modelsNIPS.decoder4, modelsNIPS.decoder5
-
-# put image into [0, 1], but don't center or normalize like for other nets
-trans = tvt.ToTensor()
-image_pt = trans(image_np).unsqueeze(0)
-
-def convert_encoder():
-    vgg_lua = [load_lua(LUA_CHECKPOINT_VGG.format(k)) for k in range(1, 6)]
-    vgg_lua_ = [e(vl) for e, vl in zip(ENCODERS, vgg_lua)]
-
-    vgg_py = vgg19_normalized.VGG19_normalized()
-
-    matching = {
-        vgg_py.blocks['conv1_1']: 2,
-        vgg_py.blocks['conv1_2']: 5,
-
-        vgg_py.blocks['conv2_1']: 9,
-        vgg_py.blocks['conv2_2']: 12,
-
-        vgg_py.blocks['conv3_1']: 16,
-        vgg_py.blocks['conv3_2']: 19,
-        vgg_py.blocks['conv3_3']: 22,
-        vgg_py.blocks['conv3_4']: 25,
-
-        vgg_py.blocks['conv4_1']: 29,
-        vgg_py.blocks['conv4_2']: 32,
-        vgg_py.blocks['conv4_3']: 35,
-        vgg_py.blocks['conv4_4']: 38,
-
-        vgg_py.blocks['conv5_1']: 42
-    }
-
-    for torch_conv, lua_conv_i in matching.items():
-        weights = nn.Parameter(vgg_lua[4].get(lua_conv_i).weight.float())
-        bias = nn.Parameter(vgg_lua[4].get(lua_conv_i).bias.float())
-        torch_conv.load_state_dict({'weight': weights, 'bias': bias})
-
-    torch.save(vgg_py.state_dict(), CHECKPOINT_ENCODER_PY)
-
-    for k in range(1, 6):
-        print(f'encoder {k}')
-        e_lua = vgg_lua_[k-1]
-        with torch.no_grad():
-            al = e_lua(image_pt)
-            ap = vgg_py(image_pt, targets=f'relu{k}_1')
-        assert al.shape == ap.shape, (al.shape, ap.shape)
-        diff = np.abs((al - ap))
-        print(describe(diff.flatten()))
-        print(np.percentile(diff, 99))
-        print()
-
-def convert_decoder(K):
-    print(f'converting decoder from layer {K}')
-    decoderK_lua = load_lua(f'models/feature_invertor_conv{K}_1.t7')
-    decoderK_legacy = DECODERS[K-1](decoderK_lua)
-    decoderK_py = vgg19_decoders.DECODERS[K-1]()
-
-    matching = {
-        'conv5_1': -41,
-
-        'conv4_4': -37,
-        'conv4_3': -34,
-        'conv4_2': -31,
-        'conv4_1': -28,
-
-        'conv3_4': -24,
-        'conv3_3': -21,
-        'conv3_2': -18,
-        'conv3_1': -15,
-
-        'conv2_2': -11,
-        'conv2_1': -8,
-
-        'conv1_2': -4,
-        'conv1_1': -1
-
-    }
-
-    for torch_conv, lua_conv_i in matching.items():
-        if -lua_conv_i >= len(decoderK_lua):
-            continue
-        print(f'  {torch_conv}')
-        weights = nn.Parameter(decoderK_lua.get(lua_conv_i).weight.float())
-        bias = nn.Parameter(decoderK_lua.get(lua_conv_i).bias.float())
-        decoderK_py.blocks[torch_conv].load_state_dict({'weight': weights, 'bias': bias})
-
-    torch.save(decoderK_py.state_dict(), f'models/vgg19_normalized_decoder{K}.pth.tar')
-
-    encoder = vgg19_normalized.VGG19_normalized()
-    encoder.load_state_dict(torch.load(CHECKPOINT_ENCODER_PY))
-
-    print(f'testing encoding/decoding at layer {K}')
-
-    with torch.no_grad():
-        features = encoder(image_pt, targets=f'relu{K}_1')
-        rgb_legacy = decoderK_legacy(features)
-        rgb_py = decoderK_py(features)
-    assert rgb_legacy.shape == rgb_py.shape, (rgb_legacy.shape, rgb_py.shape)
-    diff = np.abs((rgb_legacy - rgb_py).numpy())
-    print(describe(diff.flatten()))
-    print(np.percentile(diff, 99))
-    print()
-
-def main():
-    convert_encoder()
-
-    for K in range(1, 6):
-        convert_decoder(K)
-
-    print('DONE')
-
-
-if __name__ == '__main__':
-    main()
diff --git a/modelsNIPS.py b/modelsNIPS.py
deleted file mode 100644
index 5c6ea2a..0000000
--- a/modelsNIPS.py
+++ /dev/null
@@ -1,788 +0,0 @@
-import torch.nn as nn
-import torch
-
-class encoder1(nn.Module):
-    def __init__(self,vgg1):
-        super(encoder1,self).__init__()
-        # dissemble vgg2 and decoder2 layer by layer
-        # then resemble a new encoder-decoder network
-        # 224 x 224
-        self.conv1 = nn.Conv2d(3,3,1,1,0)
-        self.conv1.weight = torch.nn.Parameter(vgg1.get(0).weight.float())
-        self.conv1.bias = torch.nn.Parameter(vgg1.get(0).bias.float())
-        # 224 x 224
-        self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
-        # 226 x 226
-        self.conv2 = nn.Conv2d(3,64,3,1,0)
-        self.conv2.weight = torch.nn.Parameter(vgg1.get(2).weight.float())
-        self.conv2.bias = torch.nn.Parameter(vgg1.get(2).bias.float())
-
-        self.relu = nn.ReLU(inplace=True)
-        # 224 x 224
-    def forward(self,x):
-        out = self.conv1(x)
-        out = self.reflecPad1(out)
-        out = self.conv2(out)
-        out = self.relu(out)
-        return out
-
-class decoder1(nn.Module):
-    def __init__(self,d1):
-        super(decoder1,self).__init__()
-        self.reflecPad2 = nn.ReflectionPad2d((1,1,1,1))
-        # 226 x 226
-        self.conv3 = nn.Conv2d(64,3,3,1,0)
-        self.conv3.weight = torch.nn.Parameter(d1.get(1).weight.float())
-        self.conv3.bias = torch.nn.Parameter(d1.get(1).bias.float())
-        # 224 x 224
-
-    def forward(self,x):
-        out = self.reflecPad2(x)
-        out = self.conv3(out)
-        return out
-
-
-class encoder2(nn.Module):
-    def __init__(self,vgg):
-        super(encoder2,self).__init__()
-        # vgg
-        # 224 x 224
-        self.conv1 = nn.Conv2d(3,3,1,1,0)
-        self.conv1.weight = torch.nn.Parameter(vgg.get(0).weight.float())
-        self.conv1.bias = torch.nn.Parameter(vgg.get(0).bias.float())
-        self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
-        # 226 x 226
-
-        self.conv2 = nn.Conv2d(3,64,3,1,0)
-        self.conv2.weight = torch.nn.Parameter(vgg.get(2).weight.float())
-        self.conv2.bias = torch.nn.Parameter(vgg.get(2).bias.float())
-        self.relu2 = nn.ReLU(inplace=True)
-        # 224 x 224
-
-        self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv3 = nn.Conv2d(64,64,3,1,0)
-        self.conv3.weight = torch.nn.Parameter(vgg.get(5).weight.float())
-        self.conv3.bias = torch.nn.Parameter(vgg.get(5).bias.float())
-        self.relu3 = nn.ReLU(inplace=True)
-        # 224 x 224
-
-        self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
-        # 112 x 112
-
-        self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv4 = nn.Conv2d(64,128,3,1,0)
-        self.conv4.weight = torch.nn.Parameter(vgg.get(9).weight.float())
-        self.conv4.bias = torch.nn.Parameter(vgg.get(9).bias.float())
-        self.relu4 = nn.ReLU(inplace=True)
-        # 112 x 112
-
-    def forward(self,x):
-        out = self.conv1(x)
-        out = self.reflecPad1(out)
-        out = self.conv2(out)
-        out = self.relu2(out)
-        out = self.reflecPad3(out)
-        out = self.conv3(out)
-        pool = self.relu3(out)
-        out,pool_idx = self.maxPool(pool)
-        out = self.reflecPad4(out)
-        out = self.conv4(out)
-        out = self.relu4(out)
-        return out
-
-class decoder2(nn.Module):
-    def __init__(self,d):
-        super(decoder2,self).__init__()
-        # decoder
-        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv5 = nn.Conv2d(128,64,3,1,0)
-        self.conv5.weight = torch.nn.Parameter(d.get(1).weight.float())
-        self.conv5.bias = torch.nn.Parameter(d.get(1).bias.float())
-        self.relu5 = nn.ReLU(inplace=True)
-        # 112 x 112
-
-        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
-        # 224 x 224
-
-        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv6 = nn.Conv2d(64,64,3,1,0)
-        self.conv6.weight = torch.nn.Parameter(d.get(5).weight.float())
-        self.conv6.bias = torch.nn.Parameter(d.get(5).bias.float())
-        self.relu6 = nn.ReLU(inplace=True)
-        # 224 x 224
-
-        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv7 = nn.Conv2d(64,3,3,1,0)
-        self.conv7.weight = torch.nn.Parameter(d.get(8).weight.float())
-        self.conv7.bias = torch.nn.Parameter(d.get(8).bias.float())
-
-    def forward(self,x):
-        out = self.reflecPad5(x)
-        out = self.conv5(out)
-        out = self.relu5(out)
-        out = self.unpool(out)
-        out = self.reflecPad6(out)
-        out = self.conv6(out)
-        out = self.relu6(out)
-        out = self.reflecPad7(out)
-        out = self.conv7(out)
-        return out
-
-class encoder3(nn.Module):
-    def __init__(self,vgg):
-        super(encoder3,self).__init__()
-        # vgg
-        # 224 x 224
-        self.conv1 = nn.Conv2d(3,3,1,1,0)
-        self.conv1.weight = torch.nn.Parameter(vgg.get(0).weight.float())
-        self.conv1.bias = torch.nn.Parameter(vgg.get(0).bias.float())
-        self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
-        # 226 x 226
-
-        self.conv2 = nn.Conv2d(3,64,3,1,0)
-        self.conv2.weight = torch.nn.Parameter(vgg.get(2).weight.float())
-        self.conv2.bias = torch.nn.Parameter(vgg.get(2).bias.float())
-        self.relu2 = nn.ReLU(inplace=True)
-        # 224 x 224
-
-        self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv3 = nn.Conv2d(64,64,3,1,0)
-        self.conv3.weight = torch.nn.Parameter(vgg.get(5).weight.float())
-        self.conv3.bias = torch.nn.Parameter(vgg.get(5).bias.float())
-        self.relu3 = nn.ReLU(inplace=True)
-        # 224 x 224
-
-        self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
-        # 112 x 112
-
-        self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv4 = nn.Conv2d(64,128,3,1,0)
-        self.conv4.weight = torch.nn.Parameter(vgg.get(9).weight.float())
-        self.conv4.bias = torch.nn.Parameter(vgg.get(9).bias.float())
-        self.relu4 = nn.ReLU(inplace=True)
-        # 112 x 112
-
-        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv5 = nn.Conv2d(128,128,3,1,0)
-        self.conv5.weight = torch.nn.Parameter(vgg.get(12).weight.float())
-        self.conv5.bias = torch.nn.Parameter(vgg.get(12).bias.float())
-        self.relu5 = nn.ReLU(inplace=True)
-        # 112 x 112
-
-        self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
-        # 56 x 56
-
-        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv6 = nn.Conv2d(128,256,3,1,0)
-        self.conv6.weight = torch.nn.Parameter(vgg.get(16).weight.float())
-        self.conv6.bias = torch.nn.Parameter(vgg.get(16).bias.float())
-        self.relu6 = nn.ReLU(inplace=True)
-        # 56 x 56
-    def forward(self,x):
-        out = self.conv1(x)
-        out = self.reflecPad1(out)
-        out = self.conv2(out)
-        out = self.relu2(out)
-        out = self.reflecPad3(out)
-        out = self.conv3(out)
-        pool1 = self.relu3(out)
-        out,pool_idx = self.maxPool(pool1)
-        out = self.reflecPad4(out)
-        out = self.conv4(out)
-        out = self.relu4(out)
-        out = self.reflecPad5(out)
-        out = self.conv5(out)
-        pool2 = self.relu5(out)
-        out,pool_idx2 = self.maxPool2(pool2)
-        out = self.reflecPad6(out)
-        out = self.conv6(out)
-        out = self.relu6(out)
-        return out
-
-class decoder3(nn.Module):
-    def __init__(self,d):
-        super(decoder3,self).__init__()
-        # decoder
-        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv7 = nn.Conv2d(256,128,3,1,0)
-        self.conv7.weight = torch.nn.Parameter(d.get(1).weight.float())
-        self.conv7.bias = torch.nn.Parameter(d.get(1).bias.float())
-        self.relu7 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
-        # 112 x 112
-
-        self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv8 = nn.Conv2d(128,128,3,1,0)
-        self.conv8.weight = torch.nn.Parameter(d.get(5).weight.float())
-        self.conv8.bias = torch.nn.Parameter(d.get(5).bias.float())
-        self.relu8 = nn.ReLU(inplace=True)
-        # 112 x 112
-
-        self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv9 = nn.Conv2d(128,64,3,1,0)
-        self.conv9.weight = torch.nn.Parameter(d.get(8).weight.float())
-        self.conv9.bias = torch.nn.Parameter(d.get(8).bias.float())
-        self.relu9 = nn.ReLU(inplace=True)
-
-        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
-        # 224 x 224
-
-        self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv10 = nn.Conv2d(64,64,3,1,0)
-        self.conv10.weight = torch.nn.Parameter(d.get(12).weight.float())
-        self.conv10.bias = torch.nn.Parameter(d.get(12).bias.float())
-        self.relu10 = nn.ReLU(inplace=True)
-
-        self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv11 = nn.Conv2d(64,3,3,1,0)
-        self.conv11.weight = torch.nn.Parameter(d.get(15).weight.float())
-        self.conv11.bias = torch.nn.Parameter(d.get(15).bias.float())
-
-    def forward(self,x):
-        out = self.reflecPad7(x)
-        out = self.conv7(out)
-        out = self.relu7(out)
-        out = self.unpool(out)
-        out = self.reflecPad8(out)
-        out = self.conv8(out)
-        out = self.relu8(out)
-        out = self.reflecPad9(out)
-        out = self.conv9(out)
-        out = self.relu9(out)
-        out = self.unpool2(out)
-        out = self.reflecPad10(out)
-        out = self.conv10(out)
-        out = self.relu10(out)
-        out = self.reflecPad11(out)
-        out = self.conv11(out)
-        return out
-
-class encoder4(nn.Module):
-    def __init__(self,vgg):
-        super(encoder4,self).__init__()
-        # vgg
-        # 224 x 224
-        self.conv1 = nn.Conv2d(3,3,1,1,0)
-        self.conv1.weight = torch.nn.Parameter(vgg.get(0).weight.float())
-        self.conv1.bias = torch.nn.Parameter(vgg.get(0).bias.float())
-        self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
-        # 226 x 226
-
-        self.conv2 = nn.Conv2d(3,64,3,1,0)
-        self.conv2.weight = torch.nn.Parameter(vgg.get(2).weight.float())
-        self.conv2.bias = torch.nn.Parameter(vgg.get(2).bias.float())
-        self.relu2 = nn.ReLU(inplace=True)
-        # 224 x 224
-
-        self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv3 = nn.Conv2d(64,64,3,1,0)
-        self.conv3.weight = torch.nn.Parameter(vgg.get(5).weight.float())
-        self.conv3.bias = torch.nn.Parameter(vgg.get(5).bias.float())
-        self.relu3 = nn.ReLU(inplace=True)
-        # 224 x 224
-
-        self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
-        # 112 x 112
-
-        self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv4 = nn.Conv2d(64,128,3,1,0)
-        self.conv4.weight = torch.nn.Parameter(vgg.get(9).weight.float())
-        self.conv4.bias = torch.nn.Parameter(vgg.get(9).bias.float())
-        self.relu4 = nn.ReLU(inplace=True)
-        # 112 x 112
-
-        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv5 = nn.Conv2d(128,128,3,1,0)
-        self.conv5.weight = torch.nn.Parameter(vgg.get(12).weight.float())
-        self.conv5.bias = torch.nn.Parameter(vgg.get(12).bias.float())
-        self.relu5 = nn.ReLU(inplace=True)
-        # 112 x 112
-
-        self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
-        # 56 x 56
-
-        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv6 = nn.Conv2d(128,256,3,1,0)
-        self.conv6.weight = torch.nn.Parameter(vgg.get(16).weight.float())
-        self.conv6.bias = torch.nn.Parameter(vgg.get(16).bias.float())
-        self.relu6 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv7 = nn.Conv2d(256,256,3,1,0)
-        self.conv7.weight = torch.nn.Parameter(vgg.get(19).weight.float())
-        self.conv7.bias = torch.nn.Parameter(vgg.get(19).bias.float())
-        self.relu7 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv8 = nn.Conv2d(256,256,3,1,0)
-        self.conv8.weight = torch.nn.Parameter(vgg.get(22).weight.float())
-        self.conv8.bias = torch.nn.Parameter(vgg.get(22).bias.float())
-        self.relu8 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv9 = nn.Conv2d(256,256,3,1,0)
-        self.conv9.weight = torch.nn.Parameter(vgg.get(25).weight.float())
-        self.conv9.bias = torch.nn.Parameter(vgg.get(25).bias.float())
-        self.relu9 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
-        # 28 x 28
-
-        self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv10 = nn.Conv2d(256,512,3,1,0)
-        self.conv10.weight = torch.nn.Parameter(vgg.get(29).weight.float())
-        self.conv10.bias = torch.nn.Parameter(vgg.get(29).bias.float())
-        self.relu10 = nn.ReLU(inplace=True)
-        # 28 x 28
-    def forward(self,x):
-        out = self.conv1(x)
-        out = self.reflecPad1(out)
-        out = self.conv2(out)
-        out = self.relu2(out)
-        out = self.reflecPad3(out)
-        out = self.conv3(out)
-        pool1 = self.relu3(out)
-        out,pool_idx = self.maxPool(pool1)
-        out = self.reflecPad4(out)
-        out = self.conv4(out)
-        out = self.relu4(out)
-        out = self.reflecPad5(out)
-        out = self.conv5(out)
-        pool2 = self.relu5(out)
-        out,pool_idx2 = self.maxPool2(pool2)
-        out = self.reflecPad6(out)
-        out = self.conv6(out)
-        out = self.relu6(out)
-        out = self.reflecPad7(out)
-        out = self.conv7(out)
-        out = self.relu7(out)
-        out = self.reflecPad8(out)
-        out = self.conv8(out)
-        out = self.relu8(out)
-        out = self.reflecPad9(out)
-        out = self.conv9(out)
-        pool3 = self.relu9(out)
-        out,pool_idx3 = self.maxPool3(pool3)
-        out = self.reflecPad10(out)
-        out = self.conv10(out)
-        out = self.relu10(out)
-        return out
-
-class decoder4(nn.Module):
-    def __init__(self,d):
-        super(decoder4,self).__init__()
-        # decoder
-        self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv11 = nn.Conv2d(512,256,3,1,0)
-        self.conv11.weight = torch.nn.Parameter(d.get(1).weight.float())
-        self.conv11.bias = torch.nn.Parameter(d.get(1).bias.float())
-        self.relu11 = nn.ReLU(inplace=True)
-        # 28 x 28
-
-        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
-        # 56 x 56
-
-        self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv12 = nn.Conv2d(256,256,3,1,0)
-        self.conv12.weight = torch.nn.Parameter(d.get(5).weight.float())
-        self.conv12.bias = torch.nn.Parameter(d.get(5).bias.float())
-        self.relu12 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv13 = nn.Conv2d(256,256,3,1,0)
-        self.conv13.weight = torch.nn.Parameter(d.get(8).weight.float())
-        self.conv13.bias = torch.nn.Parameter(d.get(8).bias.float())
-        self.relu13 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv14 = nn.Conv2d(256,256,3,1,0)
-        self.conv14.weight = torch.nn.Parameter(d.get(11).weight.float())
-        self.conv14.bias = torch.nn.Parameter(d.get(11).bias.float())
-        self.relu14 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv15 = nn.Conv2d(256,128,3,1,0)
-        self.conv15.weight = torch.nn.Parameter(d.get(14).weight.float())
-        self.conv15.bias = torch.nn.Parameter(d.get(14).bias.float())
-        self.relu15 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
-        # 112 x 112
-
-        self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv16 = nn.Conv2d(128,128,3,1,0)
-        self.conv16.weight = torch.nn.Parameter(d.get(18).weight.float())
-        self.conv16.bias = torch.nn.Parameter(d.get(18).bias.float())
-        self.relu16 = nn.ReLU(inplace=True)
-        # 112 x 112
-
-        self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv17 = nn.Conv2d(128,64,3,1,0)
-        self.conv17.weight = torch.nn.Parameter(d.get(21).weight.float())
-        self.conv17.bias = torch.nn.Parameter(d.get(21).bias.float())
-        self.relu17 = nn.ReLU(inplace=True)
-        # 112 x 112
-
-        self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
-        # 224 x 224
-
-        self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv18 = nn.Conv2d(64,64,3,1,0)
-        self.conv18.weight = torch.nn.Parameter(d.get(25).weight.float())
-        self.conv18.bias = torch.nn.Parameter(d.get(25).bias.float())
-        self.relu18 = nn.ReLU(inplace=True)
-        # 224 x 224
-
-        self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv19 = nn.Conv2d(64,3,3,1,0)
-        self.conv19.weight = torch.nn.Parameter(d.get(28).weight.float())
-        self.conv19.bias = torch.nn.Parameter(d.get(28).bias.float())
-
-
-
-    def forward(self,x):
-        # decoder
-        out = self.reflecPad11(x)
-        out = self.conv11(out)
-        out = self.relu11(out)
-        out = self.unpool(out)
-        out = self.reflecPad12(out)
-        out = self.conv12(out)
-
-        out = self.relu12(out)
-        out = self.reflecPad13(out)
-        out = self.conv13(out)
-        out = self.relu13(out)
-        out = self.reflecPad14(out)
-        out = self.conv14(out)
-        out = self.relu14(out)
-        out = self.reflecPad15(out)
-        out = self.conv15(out)
-        out = self.relu15(out)
-        out = self.unpool2(out)
-        out = self.reflecPad16(out)
-        out = self.conv16(out)
-        out = self.relu16(out)
-        out = self.reflecPad17(out)
-        out = self.conv17(out)
-        out = self.relu17(out)
-        out = self.unpool3(out)
-        out = self.reflecPad18(out)
-        out = self.conv18(out)
-        out = self.relu18(out)
-        out = self.reflecPad19(out)
-        out = self.conv19(out)
-        return out
-class encoder5(nn.Module):
-    def __init__(self,vgg):
-        super(encoder5,self).__init__()
-        # vgg
-        # 224 x 224
-        self.conv1 = nn.Conv2d(3,3,1,1,0)
-        self.conv1.weight = torch.nn.Parameter(vgg.get(0).weight.float())
-        self.conv1.bias = torch.nn.Parameter(vgg.get(0).bias.float())
-        self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
-        # 226 x 226
-
-        self.conv2 = nn.Conv2d(3,64,3,1,0)
-        self.conv2.weight = torch.nn.Parameter(vgg.get(2).weight.float())
-        self.conv2.bias = torch.nn.Parameter(vgg.get(2).bias.float())
-        self.relu2 = nn.ReLU(inplace=True)
-        # 224 x 224
-
-        self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv3 = nn.Conv2d(64,64,3,1,0)
-        self.conv3.weight = torch.nn.Parameter(vgg.get(5).weight.float())
-        self.conv3.bias = torch.nn.Parameter(vgg.get(5).bias.float())
-        self.relu3 = nn.ReLU(inplace=True)
-        # 224 x 224
-
-        self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
-        # 112 x 112
-
-        self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv4 = nn.Conv2d(64,128,3,1,0)
-        self.conv4.weight = torch.nn.Parameter(vgg.get(9).weight.float())
-        self.conv4.bias = torch.nn.Parameter(vgg.get(9).bias.float())
-        self.relu4 = nn.ReLU(inplace=True)
-        # 112 x 112
-
-        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv5 = nn.Conv2d(128,128,3,1,0)
-        self.conv5.weight = torch.nn.Parameter(vgg.get(12).weight.float())
-        self.conv5.bias = torch.nn.Parameter(vgg.get(12).bias.float())
-        self.relu5 = nn.ReLU(inplace=True)
-        # 112 x 112
-
-        self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
-        # 56 x 56
-
-        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv6 = nn.Conv2d(128,256,3,1,0)
-        self.conv6.weight = torch.nn.Parameter(vgg.get(16).weight.float())
-        self.conv6.bias = torch.nn.Parameter(vgg.get(16).bias.float())
-        self.relu6 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv7 = nn.Conv2d(256,256,3,1,0)
-        self.conv7.weight = torch.nn.Parameter(vgg.get(19).weight.float())
-        self.conv7.bias = torch.nn.Parameter(vgg.get(19).bias.float())
-        self.relu7 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv8 = nn.Conv2d(256,256,3,1,0)
-        self.conv8.weight = torch.nn.Parameter(vgg.get(22).weight.float())
-        self.conv8.bias = torch.nn.Parameter(vgg.get(22).bias.float())
-        self.relu8 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv9 = nn.Conv2d(256,256,3,1,0)
-        self.conv9.weight = torch.nn.Parameter(vgg.get(25).weight.float())
-        self.conv9.bias = torch.nn.Parameter(vgg.get(25).bias.float())
-        self.relu9 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
-        # 28 x 28
-
-        self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv10 = nn.Conv2d(256,512,3,1,0)
-        self.conv10.weight = torch.nn.Parameter(vgg.get(29).weight.float())
-        self.conv10.bias = torch.nn.Parameter(vgg.get(29).bias.float())
-        self.relu10 = nn.ReLU(inplace=True)
-        # 28 x 28
-
-        self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv11 = nn.Conv2d(512,512,3,1,0)
-        self.conv11.weight = torch.nn.Parameter(vgg.get(32).weight.float())
-        self.conv11.bias = torch.nn.Parameter(vgg.get(32).bias.float())
-        self.relu11 = nn.ReLU(inplace=True)
-        # 28 x 28
-
-        self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv12 = nn.Conv2d(512,512,3,1,0)
-        self.conv12.weight = torch.nn.Parameter(vgg.get(35).weight.float())
-        self.conv12.bias = torch.nn.Parameter(vgg.get(35).bias.float())
-        self.relu12 = nn.ReLU(inplace=True)
-        # 28 x 28
-
-        self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv13 = nn.Conv2d(512,512,3,1,0)
-        self.conv13.weight = torch.nn.Parameter(vgg.get(38).weight.float())
-        self.conv13.bias = torch.nn.Parameter(vgg.get(38).bias.float())
-        self.relu13 = nn.ReLU(inplace=True)
-        # 28 x 28
-
-        self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
-        # 14 x 14
-
-        self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv14 = nn.Conv2d(512,512,3,1,0)
-        self.conv14.weight = torch.nn.Parameter(vgg.get(42).weight.float())
-        self.conv14.bias = torch.nn.Parameter(vgg.get(42).bias.float())
-        self.relu14 = nn.ReLU(inplace=True)
-        # 14 x 14
-    def forward(self,x):
-        out = self.conv1(x)
-        out = self.reflecPad1(out)
-        out = self.conv2(out)
-        out = self.relu2(out)
-        out = self.reflecPad3(out)
-        out = self.conv3(out)
-        out = self.relu3(out)
-        out,pool_idx = self.maxPool(out)
-        out = self.reflecPad4(out)
-        out = self.conv4(out)
-        out = self.relu4(out)
-        out = self.reflecPad5(out)
-        out = self.conv5(out)
-        out = self.relu5(out)
-        out,pool_idx2 = self.maxPool2(out)
-        out = self.reflecPad6(out)
-        out = self.conv6(out)
-        out = self.relu6(out)
-        out = self.reflecPad7(out)
-        out = self.conv7(out)
-        out = self.relu7(out)
-        out = self.reflecPad8(out)
-        out = self.conv8(out)
-        out = self.relu8(out)
-        out = self.reflecPad9(out)
-        out = self.conv9(out)
-        out = self.relu9(out)
-        out,pool_idx3 = self.maxPool3(out)
-        out = self.reflecPad10(out)
-        out = self.conv10(out)
-        out = self.relu10(out)
-        out = self.reflecPad11(out)
-        out = self.conv11(out)
-        out = self.relu11(out)
-        out = self.reflecPad12(out)
-        out = self.conv12(out)
-        out = self.relu12(out)
-        out = self.reflecPad13(out)
-        out = self.conv13(out)
-        out = self.relu13(out)
-        out,pool_idx4 = self.maxPool4(out)
-        out = self.reflecPad14(out)
-        out = self.conv14(out)
-        out = self.relu14(out)
-        return out
-
-
-class decoder5(nn.Module):
-    def __init__(self,d):
-        super(decoder5,self).__init__()
-
-        # decoder
-        self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv15 = nn.Conv2d(512,512,3,1,0)
-        self.conv15.weight = torch.nn.Parameter(d.get(1).weight.float())
-        self.conv15.bias = torch.nn.Parameter(d.get(1).bias.float())
-        self.relu15 = nn.ReLU(inplace=True)
-
-        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
-        # 28 x 28
-
-        self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv16 = nn.Conv2d(512,512,3,1,0)
-        self.conv16.weight = torch.nn.Parameter(d.get(5).weight.float())
-        self.conv16.bias = torch.nn.Parameter(d.get(5).bias.float())
-        self.relu16 = nn.ReLU(inplace=True)
-        # 28 x 28
-
-        self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv17 = nn.Conv2d(512,512,3,1,0)
-        self.conv17.weight = torch.nn.Parameter(d.get(8).weight.float())
-        self.conv17.bias = torch.nn.Parameter(d.get(8).bias.float())
-        self.relu17 = nn.ReLU(inplace=True)
-        # 28 x 28
-
-        self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv18 = nn.Conv2d(512,512,3,1,0)
-        self.conv18.weight = torch.nn.Parameter(d.get(11).weight.float())
-        self.conv18.bias = torch.nn.Parameter(d.get(11).bias.float())
-        self.relu18 = nn.ReLU(inplace=True)
-        # 28 x 28
-
-        self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv19 = nn.Conv2d(512,256,3,1,0)
-        self.conv19.weight = torch.nn.Parameter(d.get(14).weight.float())
-        self.conv19.bias = torch.nn.Parameter(d.get(14).bias.float())
-        self.relu19 = nn.ReLU(inplace=True)
-        # 28 x 28
-
-        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
-        # 56 x 56
-
-        self.reflecPad20 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv20 = nn.Conv2d(256,256,3,1,0)
-        self.conv20.weight = torch.nn.Parameter(d.get(18).weight.float())
-        self.conv20.bias = torch.nn.Parameter(d.get(18).bias.float())
-        self.relu20 = nn.ReLU(inplace=True)
-        # 56 x 56
-
-        self.reflecPad21 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv21 = nn.Conv2d(256,256,3,1,0)
-        self.conv21.weight = torch.nn.Parameter(d.get(21).weight.float())
-        self.conv21.bias = torch.nn.Parameter(d.get(21).bias.float())
-        self.relu21 = nn.ReLU(inplace=True)
-
-        self.reflecPad22 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv22 = nn.Conv2d(256,256,3,1,0)
-        self.conv22.weight = torch.nn.Parameter(d.get(24).weight.float())
-        self.conv22.bias = torch.nn.Parameter(d.get(24).bias.float())
-        self.relu22 = nn.ReLU(inplace=True)
-
-        self.reflecPad23 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv23 = nn.Conv2d(256,128,3,1,0)
-        self.conv23.weight = torch.nn.Parameter(d.get(27).weight.float())
-        self.conv23.bias = torch.nn.Parameter(d.get(27).bias.float())
-        self.relu23 = nn.ReLU(inplace=True)
-
-        self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
-        # 112 X 112
-
-        self.reflecPad24 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv24 = nn.Conv2d(128,128,3,1,0)
-        self.conv24.weight = torch.nn.Parameter(d.get(31).weight.float())
-        self.conv24.bias = torch.nn.Parameter(d.get(31).bias.float())
-        self.relu24 = nn.ReLU(inplace=True)
-
-        self.reflecPad25 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv25 = nn.Conv2d(128,64,3,1,0)
-        self.conv25.weight = torch.nn.Parameter(d.get(34).weight.float())
-        self.conv25.bias = torch.nn.Parameter(d.get(34).bias.float())
-        self.relu25 = nn.ReLU(inplace=True)
-
-        self.unpool4 = nn.UpsamplingNearest2d(scale_factor=2)
-
-        self.reflecPad26 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv26 = nn.Conv2d(64,64,3,1,0)
-        self.conv26.weight = torch.nn.Parameter(d.get(38).weight.float())
-        self.conv26.bias = torch.nn.Parameter(d.get(38).bias.float())
-        self.relu26 = nn.ReLU(inplace=True)
-
-        self.reflecPad27 = nn.ReflectionPad2d((1,1,1,1))
-        self.conv27 = nn.Conv2d(64,3,3,1,0)
-        self.conv27.weight = torch.nn.Parameter(d.get(41).weight.float())
-        self.conv27.bias = torch.nn.Parameter(d.get(41).bias.float())
-
-    def forward(self,x):
-        # decoder
-        out = self.reflecPad15(x)
-        out = self.conv15(out)
-        out = self.relu15(out)
-        out = self.unpool(out)
-        out = self.reflecPad16(out)
-        out = self.conv16(out)
-        out = self.relu16(out)
-        out = self.reflecPad17(out)
-        out = self.conv17(out)
-        out = self.relu17(out)
-        out = self.reflecPad18(out)
-        out = self.conv18(out)
-        out = self.relu18(out)
-        out = self.reflecPad19(out)
-        out = self.conv19(out)
-        out = self.relu19(out)
-        out = self.unpool2(out)
-        out = self.reflecPad20(out)
-        out = self.conv20(out)
-        out = self.relu20(out)
-        out = self.reflecPad21(out)
-        out = self.conv21(out)
-        out = self.relu21(out)
-        out = self.reflecPad22(out)
-        out = self.conv22(out)
-        out = self.relu22(out)
-        out = self.reflecPad23(out)
-        out = self.conv23(out)
-        out = self.relu23(out)
-        out = self.unpool3(out)
-        out = self.reflecPad24(out)
-        out = self.conv24(out)
-        out = self.relu24(out)
-        out = self.reflecPad25(out)
-        out = self.conv25(out)
-        out = self.relu25(out)
-        out = self.unpool4(out)
-        out = self.reflecPad26(out)
-        out = self.conv26(out)
-        out = self.relu26(out)
-        out = self.reflecPad27(out)
-        out = self.conv27(out)
-        return out
diff --git a/modelsNIPS.pyc b/modelsNIPS.pyc
deleted file mode 100644
index 7d59140..0000000
Binary files a/modelsNIPS.pyc and /dev/null differ
diff --git a/samples/in1.jpg b/samples/in1.jpg
index acf03d1..b1dd50b 100644
Binary files a/samples/in1.jpg and b/samples/in1.jpg differ
diff --git a/samples/in2.jpg b/samples/in2.jpg
index 3e3bc28..1170888 100644
Binary files a/samples/in2.jpg and b/samples/in2.jpg differ
diff --git a/samples/in3.jpg b/samples/in3.jpg
index 3bc21af..75ca7cd 100644
Binary files a/samples/in3.jpg and b/samples/in3.jpg differ
diff --git a/samples/in4.jpg b/samples/in4.jpg
index 8e277d7..223ccbe 100644
Binary files a/samples/in4.jpg and b/samples/in4.jpg differ
diff --git a/util.py b/util.py
index e9a7384..c4a272f 100644
--- a/util.py
+++ b/util.py
@@ -1,14 +1,13 @@
 from __future__ import division
 import torch
-from torch.utils.serialization import load_lua
 import torchvision.transforms as transforms
 import numpy as np
 import argparse
 import time
 import os
 from PIL import Image
-from modelsNIPS import decoder1,decoder2,decoder3,decoder4,decoder5
-from modelsNIPS import encoder1,encoder2,encoder3,encoder4,encoder5
+from vgg19_decoders import VGG19Decoder1, VGG19Decoder2, VGG19Decoder3, VGG19Decoder4, VGG19Decoder5
+from vgg19_normalized import VGG19_normalized
 
 import torch.nn as nn
 
@@ -17,28 +16,19 @@ class WCT(nn.Module):
     def __init__(self,args):
         super(WCT, self).__init__()
         # load pre-trained network
-        vgg1 = load_lua(args.vgg1)
-        decoder1_torch = load_lua(args.decoder1)
-        vgg2 = load_lua(args.vgg2)
-        decoder2_torch = load_lua(args.decoder2)
-        vgg3 = load_lua(args.vgg3)
-        decoder3_torch = load_lua(args.decoder3)
-        vgg4 = load_lua(args.vgg4)
-        decoder4_torch = load_lua(args.decoder4)
-        vgg5 = load_lua(args.vgg5)
-        decoder5_torch = load_lua(args.decoder5)
+        self.encoder = VGG19_normalized()
+        self.encoder.load_state_dict(torch.load(args.encoder))
 
-
-        self.e1 = encoder1(vgg1)
-        self.d1 = decoder1(decoder1_torch)
-        self.e2 = encoder2(vgg2)
-        self.d2 = decoder2(decoder2_torch)
-        self.e3 = encoder3(vgg3)
-        self.d3 = decoder3(decoder3_torch)
-        self.e4 = encoder4(vgg4)
-        self.d4 = decoder4(decoder4_torch)
-        self.e5 = encoder5(vgg5)
-        self.d5 = decoder5(decoder5_torch)
+        self.d1 = VGG19Decoder1()
+        self.d1.load_state_dict(torch.load(args.decoder1))
+        self.d2 = VGG19Decoder2()
+        self.d2.load_state_dict(torch.load(args.decoder2))
+        self.d3 = VGG19Decoder3()
+        self.d3.load_state_dict(torch.load(args.decoder3))
+        self.d4 = VGG19Decoder4()
+        self.d4.load_state_dict(torch.load(args.decoder4))
+        self.d5 = VGG19Decoder5()
+        self.d5.load_state_dict(torch.load(args.decoder5))
 
     def whiten_and_color(self,cF,sF):
         cFSize = cF.size()
@@ -77,7 +67,7 @@ def whiten_and_color(self,cF,sF):
         targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature)
         return targetFeature
 
-    def transform(self,cF,sF,csF,alpha):
+    def transform(self,cF,sF,alpha):
         cF = cF.double()
         sF = sF.double()
         C,W,H = cF.size(0),cF.size(1),cF.size(2)
@@ -87,7 +77,6 @@ def transform(self,cF,sF,csF,alpha):
         targetFeature = self.whiten_and_color(cFView,sFView)
         targetFeature = targetFeature.view_as(cF)
-        ccsF = alpha * targetFeature + (1.0 - alpha) * cF
-        ccsF = ccsF.float().unsqueeze(0)
-        csF.data.resize_(ccsF.size()).copy_(ccsF)
+        csF = alpha * targetFeature + (1.0 - alpha) * cF
+        csF = csF.float().unsqueeze(0)
         return csF
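
Reviewer note: the refactored transform() above now simply returns the blended
feature instead of copying into a pre-allocated csF buffer. For readers new to
WCT, the standalone sketch below spells out the whitening-and-coloring math
that whiten_and_color() and transform() implement together. It is an
illustration under assumptions, not the repository's exact code: the name
wct_blend and the eps ridge (which stands in for the small-eigenvalue cutoff
used in whiten_and_color) are ours.

import torch

def wct_blend(cF, sF, alpha=0.6, eps=1e-5):
    # cF, sF: (C, H*W) content/style features from the same encoder layer.
    c_mean = cF.mean(1, keepdim=True)
    s_mean = sF.mean(1, keepdim=True)
    cFc, sFc = cF - c_mean, sF - s_mean
    # Channel covariances, ridge-regularized so the inverse square root exists.
    covC = cFc @ cFc.t() / (cFc.size(1) - 1) + eps * torch.eye(cF.size(0))
    covS = sFc @ sFc.t() / (sFc.size(1) - 1) + eps * torch.eye(sF.size(0))
    Ec, Lc, _ = torch.svd(covC)    # covC = Ec diag(Lc) Ec^T
    Es, Ls, _ = torch.svd(covS)
    # Whiten the content features, then re-color them with style statistics.
    whitened = Ec @ torch.diag(Lc.pow(-0.5)) @ Ec.t() @ cFc
    colored = Es @ torch.diag(Ls.pow(0.5)) @ Es.t() @ whitened + s_mean
    # Same blend as transform(): alpha * stylized + (1 - alpha) * content.
    return alpha * colored + (1.0 - alpha) * cF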
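
A quick end-to-end check of the converted checkpoints, mirroring the
reconstruction test that convertModels.py performed before its removal. A
minimal sketch, assuming models_pytorch.zip from the updated Readme has been
unpacked under models/; the class and file names are the ones introduced in
this diff, while the output file name recon_in4.png is arbitrary:

import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
from vgg19_normalized import VGG19_normalized
from vgg19_decoders import VGG19Decoder3

encoder = VGG19_normalized()
encoder.load_state_dict(torch.load('models/vgg19_normalized.pth.tar'))
decoder = VGG19Decoder3()
decoder.load_state_dict(torch.load('models/vgg19_normalized_decoder3.pth.tar'))

# Images enter in [0, 1] RGB with no mean/std normalization, as in convertModels.py.
img = transforms.ToTensor()(Image.open('images/content/in4.jpg')).unsqueeze(0)
with torch.no_grad():
    feat = encoder(img, 'relu3_1')   # encode up to relu3_1
    recon = decoder(feat)            # invert back to RGB
vutils.save_image(recon.clamp(0, 1), 'recon_in4.png')

If this round-trips cleanly, python WCT.py (optionally with --cuda) exercises
the full five-level pipeline with the same weights.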