migrate to pytorch 1.0.
use torch.no_grad() instead of volatile variables.
update example images to also use relu5_1 stylization.
Point models download to new file.
Remove model conversion script.
add .gitignore
black-puppydog committed Dec 18, 2018
1 parent 452846f commit bb093b9
Showing 11 changed files with 94 additions and 1,014 deletions.
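The headline change is replacing the pre-1.0 volatile-Variable idiom with torch.no_grad(), which the new WCT.py does by wrapping main() (see the diff below). A minimal sketch of the two idioms for comparison; the names stylize and model are placeholders, not identifiers from this repository:

```python
import torch

# Pre-1.0 inference idiom (removed by this commit): volatile Variables.
#   from torch.autograd import Variable
#   cImg = Variable(contentImg, volatile=True)   # disables graph tracking

# PyTorch 1.0 idiom (adopted by this commit): plain tensors under no_grad().
def stylize(model, contentImg):
    with torch.no_grad():            # autograd is disabled for the whole block
        return model(contentImg)     # no graph is built, keeping memory low
```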
12 changes: 12 additions & 0 deletions .gitignore
@@ -0,0 +1,12 @@
+models
+models.zip
+models_pytorch.zip
+
+# python artifacts
+__pycache__
+*.pyc
+
+# Temporary Files
+*~
+*.swp
+*.swo
2 changes: 1 addition & 1 deletion Readme.md
@@ -7,7 +7,7 @@ Official Torch implementation can be found [here](https://github.com/Yijunmaveri
 ## Prerequisites
 - [Pytorch](http://pytorch.org/)
 - [torchvision](https://github.com/pytorch/vision)
-- Pretrained encoder and decoder [models](https://drive.google.com/file/d/1M5KBPfqrIUZqrBZf78CIxLrMUT4lD4t9/view?usp=sharing) for image reconstruction only (download and uncompress them under models/)
+- Pretrained encoder and decoder [models](http://pascal.inrialpes.fr/data2/archetypal_style/models_pytorch.zip) for image reconstruction only (download and uncompress them under models/)
 - CUDA + CuDNN
 
 ## Prepare images
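The new download is a single archive of .pth.tar checkpoints rather than the old per-layer .t7 files. A rough sketch of loading them with torch.load, assuming they hold state dicts; the page does not show the checkpoint format, and the variable names below are placeholders:

```python
import torch

# Hypothetical loading sketch; the actual checkpoint contents are not shown here.
encoder_weights = torch.load('models/vgg19_normalized.pth.tar', map_location='cpu')
decoder5_weights = torch.load('models/vgg19_normalized_decoder5.pth.tar', map_location='cpu')

# If these are state dicts, they would be applied to matching module definitions:
# encoder.load_state_dict(encoder_weights)
# decoder5.load_state_dict(decoder5_weights)
```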
128 changes: 64 additions & 64 deletions WCT.py
@@ -8,23 +8,18 @@
 from Loader import Dataset
 from util import *
 import scipy.misc
-from torch.utils.serialization import load_lua
 import time
 
 parser = argparse.ArgumentParser(description='WCT Pytorch')
 parser.add_argument('--contentPath',default='images/content',help='path to train')
 parser.add_argument('--stylePath',default='images/style',help='path to train')
 parser.add_argument('--workers', default=2, type=int, metavar='N',help='number of data loading workers (default: 2)')
-parser.add_argument('--vgg1', default='models/vgg_normalised_conv1_1.t7', help='Path to the VGG conv1_1')
-parser.add_argument('--vgg2', default='models/vgg_normalised_conv2_1.t7', help='Path to the VGG conv2_1')
-parser.add_argument('--vgg3', default='models/vgg_normalised_conv3_1.t7', help='Path to the VGG conv3_1')
-parser.add_argument('--vgg4', default='models/vgg_normalised_conv4_1.t7', help='Path to the VGG conv4_1')
-parser.add_argument('--vgg5', default='models/vgg_normalised_conv5_1.t7', help='Path to the VGG conv5_1')
-parser.add_argument('--decoder5', default='models/feature_invertor_conv5_1.t7', help='Path to the decoder5')
-parser.add_argument('--decoder4', default='models/feature_invertor_conv4_1.t7', help='Path to the decoder4')
-parser.add_argument('--decoder3', default='models/feature_invertor_conv3_1.t7', help='Path to the decoder3')
-parser.add_argument('--decoder2', default='models/feature_invertor_conv2_1.t7', help='Path to the decoder2')
-parser.add_argument('--decoder1', default='models/feature_invertor_conv1_1.t7', help='Path to the decoder1')
+parser.add_argument('--encoder', default='models/vgg19_normalized.pth.tar', help='Path to the normalised VGG19 encoder')
+parser.add_argument('--decoder5', default='models/vgg19_normalized_decoder5.pth.tar', help='Path to the decoder5')
+parser.add_argument('--decoder4', default='models/vgg19_normalized_decoder4.pth.tar', help='Path to the decoder4')
+parser.add_argument('--decoder3', default='models/vgg19_normalized_decoder3.pth.tar', help='Path to the decoder3')
+parser.add_argument('--decoder2', default='models/vgg19_normalized_decoder2.pth.tar', help='Path to the decoder2')
+parser.add_argument('--decoder1', default='models/vgg19_normalized_decoder1.pth.tar', help='Path to the decoder1')
 parser.add_argument('--cuda', action='store_true', help='enables cuda')
 parser.add_argument('--batch_size', type=int, default=1, help='batch size')
 parser.add_argument('--fineSize', type=int, default=512, help='resize image to fineSize x fineSize; set to 0 to keep the original size')
@@ -45,70 +40,75 @@
     batch_size=1,
     shuffle=False)
 
-wct = WCT(args)
-def styleTransfer(contentImg,styleImg,imname,csF):
+def styleTransfer(wct, contentImg, styleImg, imname):
 
-    sF5 = wct.e5(styleImg)
-    cF5 = wct.e5(contentImg)
-    sF5 = sF5.data.cpu().squeeze(0)
-    cF5 = cF5.data.cpu().squeeze(0)
-    csF5 = wct.transform(cF5,sF5,csF,args.alpha)
+    sF5 = wct.encoder(styleImg, 'relu5_1')
+    cF5 = wct.encoder(contentImg, 'relu5_1')
+    sF5 = sF5.cpu().squeeze(0)
+    cF5 = cF5.cpu().squeeze(0)
+    csF5 = wct.transform(cF5,sF5,args.alpha)
+    csF5 = csF5.to(next(wct.parameters()).device)
     Im5 = wct.d5(csF5)
 
-    sF4 = wct.e4(styleImg)
-    cF4 = wct.e4(Im5)
-    sF4 = sF4.data.cpu().squeeze(0)
-    cF4 = cF4.data.cpu().squeeze(0)
-    csF4 = wct.transform(cF4,sF4,csF,args.alpha)
+    sF4 = wct.encoder(styleImg, 'relu4_1')
+    cF4 = wct.encoder(Im5, 'relu4_1')
+    sF4 = sF4.cpu().squeeze(0)
+    cF4 = cF4.cpu().squeeze(0)
+    csF4 = wct.transform(cF4,sF4,args.alpha)
+    csF4 = csF4.to(next(wct.parameters()).device)
     Im4 = wct.d4(csF4)
 
-    sF3 = wct.e3(styleImg)
-    cF3 = wct.e3(Im4)
-    sF3 = sF3.data.cpu().squeeze(0)
-    cF3 = cF3.data.cpu().squeeze(0)
-    csF3 = wct.transform(cF3,sF3,csF,args.alpha)
+    sF3 = wct.encoder(styleImg, 'relu3_1')
+    cF3 = wct.encoder(Im4, 'relu3_1')
+    sF3 = sF3.cpu().squeeze(0)
+    cF3 = cF3.cpu().squeeze(0)
+    csF3 = wct.transform(cF3,sF3,args.alpha)
+    csF3 = csF3.to(next(wct.parameters()).device)
     Im3 = wct.d3(csF3)
 
-    sF2 = wct.e2(styleImg)
-    cF2 = wct.e2(Im3)
-    sF2 = sF2.data.cpu().squeeze(0)
-    cF2 = cF2.data.cpu().squeeze(0)
-    csF2 = wct.transform(cF2,sF2,csF,args.alpha)
+    sF2 = wct.encoder(styleImg, 'relu2_1')
+    cF2 = wct.encoder(Im3, 'relu2_1')
+    sF2 = sF2.cpu().squeeze(0)
+    cF2 = cF2.cpu().squeeze(0)
+    csF2 = wct.transform(cF2,sF2,args.alpha)
+    csF2 = csF2.to(next(wct.parameters()).device)
     Im2 = wct.d2(csF2)
 
-    sF1 = wct.e1(styleImg)
-    cF1 = wct.e1(Im2)
-    sF1 = sF1.data.cpu().squeeze(0)
-    cF1 = cF1.data.cpu().squeeze(0)
-    csF1 = wct.transform(cF1,sF1,csF,args.alpha)
+    sF1 = wct.encoder(styleImg, 'relu1_1')
+    cF1 = wct.encoder(Im2, 'relu1_1')
+    sF1 = sF1.cpu().squeeze(0)
+    cF1 = cF1.cpu().squeeze(0)
+    csF1 = wct.transform(cF1,sF1,args.alpha)
+    csF1 = csF1.to(next(wct.parameters()).device)
     Im1 = wct.d1(csF1)
     # save_image has this weird design to pad images with 4 pixels by default.
-    vutils.save_image(Im1.data.cpu().float(),os.path.join(args.outf,imname))
+    vutils.save_image(Im1.cpu().float(),os.path.join(args.outf,imname))
     return
 
-avgTime = 0
-cImg = torch.Tensor()
-sImg = torch.Tensor()
-csF = torch.Tensor()
-csF = Variable(csF)
-if(args.cuda):
-    cImg = cImg.cuda(args.gpu)
-    sImg = sImg.cuda(args.gpu)
-    csF = csF.cuda(args.gpu)
-    wct.cuda(args.gpu)
-for i,(contentImg,styleImg,imname) in enumerate(loader):
-    imname = imname[0]
-    print('Transferring ' + imname)
-    if (args.cuda):
-        contentImg = contentImg.cuda(args.gpu)
-        styleImg = styleImg.cuda(args.gpu)
-    cImg = Variable(contentImg,volatile=True)
-    sImg = Variable(styleImg,volatile=True)
-    start_time = time.time()
-    # WCT Style Transfer
-    styleTransfer(cImg,sImg,imname,csF)
-    end_time = time.time()
-    print('Elapsed time is: %f' % (end_time - start_time))
-    avgTime += (end_time - start_time)
+def main():
+    wct = WCT(args)
+    if(args.cuda):
+        wct.cuda(args.gpu)
 
-print('Processed %d images. Averaged time is %f' % ((i+1),avgTime/(i+1)))
+    avgTime = 0
+    for i,(contentImg,styleImg,imname) in enumerate(loader):
+        if(args.cuda):
+            contentImg = contentImg.cuda(args.gpu)
+            styleImg = styleImg.cuda(args.gpu)
+        imname = imname[0]
+        print('Transferring ' + imname)
+        if (args.cuda):
+            contentImg = contentImg.cuda(args.gpu)
+            styleImg = styleImg.cuda(args.gpu)
+        start_time = time.time()
+        # WCT Style Transfer
+        styleTransfer(wct, contentImg, styleImg, imname)
+        end_time = time.time()
+        print('Elapsed time is: %f' % (end_time - start_time))
+        avgTime += (end_time - start_time)
+
+    print('Processed %d images. Averaged time is %f' % ((i+1),avgTime/(i+1)))
+
+if __name__ == '__main__':
+    with torch.no_grad():
+        main()
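The diff calls wct.transform(cF, sF, args.alpha) at every level but never shows its body. For orientation only, here is a rough, self-contained sketch of the whitening-coloring transform (WCT) from Li et al.'s universal style transfer, which repositories like this one typically implement; it is not this repository's implementation, and the function and variable names are illustrative.

```python
import torch

def wct_sketch(cF, sF, alpha=1.0, eps=1e-5):
    """Rough sketch of a whitening-coloring transform.

    cF, sF: content/style feature maps of shape (C, H, W).
    Illustrates the usual math only, not the repo's actual code.
    """
    C, H, W = cF.shape
    cF_flat = cF.reshape(C, -1)
    sF_flat = sF.reshape(C, -1)

    # Center the features and compute channel covariance matrices.
    c_mean = cF_flat.mean(dim=1, keepdim=True)
    s_mean = sF_flat.mean(dim=1, keepdim=True)
    cF_c = cF_flat - c_mean
    sF_c = sF_flat - s_mean
    eye = torch.eye(C, device=cF.device, dtype=cF.dtype)
    c_cov = cF_c @ cF_c.t() / (cF_c.shape[1] - 1) + eps * eye
    s_cov = sF_c @ sF_c.t() / (sF_c.shape[1] - 1) + eps * eye

    # Whitening: remove the content covariance (SVD of a symmetric PSD matrix
    # acts as an eigendecomposition here).
    c_u, c_d, _ = torch.svd(c_cov)
    whitened = c_u @ torch.diag(c_d.clamp(min=eps).rsqrt()) @ c_u.t() @ cF_c

    # Coloring: impose the style covariance, then re-add the style mean.
    s_u, s_d, _ = torch.svd(s_cov)
    colored = s_u @ torch.diag(s_d.clamp(min=eps).sqrt()) @ s_u.t() @ whitened
    colored = colored + s_mean

    # Blend stylized and original content features.
    out = alpha * colored + (1.0 - alpha) * cF_flat
    return out.reshape(C, H, W)
```

With alpha = 1 the content features fully take on the style statistics; the script's --alpha flag presumably interpolates in this way, but the exact convention lives in code not shown on this page.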
133 changes: 0 additions & 133 deletions convertModels.py

This file was deleted.
